From 7a614ce6973847bd43554c9cddff687f3115beb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Fri, 17 Jan 2020 15:26:21 +0100 Subject: [PATCH] Support triangular git workflows in `pr create` - The local git remotes are scanned and resolved to GitHub repositories - The "base" repo is the first result resolved to its parent repo (if a fork) - The name of the default branch is read from the base repo - The "head" repo is the first repo that has push access --- api/client.go | 30 ++++-- api/queries_pr.go | 10 +- api/queries_repo.go | 145 +++++++++++++++++++++++++- command/pr_create.go | 243 +++++++++++++++++++++++++++++++------------ context/remote.go | 20 ++++ 5 files changed, 364 insertions(+), 84 deletions(-) diff --git a/api/client.go b/api/client.go index 2aa82d1ea..f3100815d 100644 --- a/api/client.go +++ b/api/client.go @@ -7,6 +7,7 @@ import ( "io" "io/ioutil" "net/http" + "strings" ) // ClientOption represents an argument to NewClient @@ -69,9 +70,27 @@ type Client struct { type graphQLResponse struct { Data interface{} - Errors []struct { - Message string + Errors []GraphQLError +} + +// GraphQLError is a single error returned in a GraphQL response +type GraphQLError struct { + Type string + Path []string + Message string +} + +// GraphQLErrorResponse contains errors returned in a GraphQL response +type GraphQLErrorResponse struct { + Errors []GraphQLError +} + +func (gr GraphQLErrorResponse) Error() string { + errorMessages := make([]string, 0, len(gr.Errors)) + for _, e := range gr.Errors { + errorMessages = append(errorMessages, e.Message) } + return fmt.Sprintf("graphql error: '%s'", strings.Join(errorMessages, ", ")) } // GraphQL performs a GraphQL request and parses the response @@ -151,14 +170,9 @@ func handleResponse(resp *http.Response, data interface{}) error { } if len(gr.Errors) > 0 { - errorMessages := gr.Errors[0].Message - for _, e := range gr.Errors[1:] { - errorMessages += ", " + e.Message - } - return fmt.Errorf("graphql error: '%s'", errorMessages) + return &GraphQLErrorResponse{Errors: gr.Errors} } return nil - } func handleHTTPError(resp *http.Response) error { diff --git a/api/queries_pr.go b/api/queries_pr.go index 9603bf0f2..382d4255b 100644 --- a/api/queries_pr.go +++ b/api/queries_pr.go @@ -403,12 +403,8 @@ func PullRequestForBranch(client *Client, ghRepo Repo, branch string) (*PullRequ return nil, &NotFoundError{fmt.Errorf("no open pull requests found for branch %q", branch)} } -func CreatePullRequest(client *Client, ghRepo Repo, params map[string]interface{}) (*PullRequest, error) { - repo, err := GitHubRepo(client, ghRepo) - if err != nil { - return nil, err - } - +// CreatePullRequest creates a pull request in a GitHub repository +func CreatePullRequest(client *Client, repo Repository, params map[string]interface{}) (*PullRequest, error) { query := ` mutation CreatePullRequest($input: CreatePullRequestInput!) { createPullRequest(input: $input) { @@ -434,7 +430,7 @@ func CreatePullRequest(client *Client, ghRepo Repo, params map[string]interface{ } }{} - err = client.GraphQL(query, variables, &result) + err := client.GraphQL(query, variables, &result) if err != nil { return nil, err } diff --git a/api/queries_repo.go b/api/queries_repo.go index f974b41cb..2b34794f7 100644 --- a/api/queries_repo.go +++ b/api/queries_repo.go @@ -1,15 +1,59 @@ package api import ( + "bytes" + "encoding/json" "fmt" + "sort" + "strings" "github.com/pkg/errors" ) // Repository contains information about a GitHub repo type Repository struct { - ID string + ID string + Name string + Owner struct { + Login string + } + + IsPrivate bool HasIssuesEnabled bool + ViewerPermission string + DefaultBranchRef struct { + Name string + Target struct { + OID string + } + } + + Parent *Repository +} + +// RepoOwner is the login name of the owner +func (r Repository) RepoOwner() string { + return r.Owner.Login +} + +// RepoName is the name of the repository +func (r Repository) RepoName() string { + return r.Name +} + +// IsFork is true when this repository has a parent repository +func (r Repository) IsFork() bool { + return r.Parent != nil +} + +// ViewerCanPush is true when the requesting user has push access +func (r Repository) ViewerCanPush() bool { + switch r.ViewerPermission { + case "ADMIN", "MAINTAIN", "WRITE": + return true + default: + return false + } } // GitHubRepo looks up the node ID of a named repository @@ -44,3 +88,102 @@ func GitHubRepo(client *Client, ghRepo Repo) (*Repository, error) { return &result.Repository, nil } + +// RepoNetworkResult describes the relationship between related repositories +type RepoNetworkResult struct { + ViewerLogin string + Repositories []*Repository +} + +// RepoNetwork inspects the relationship between multiple GitHub repositories +func RepoNetwork(client *Client, repos []Repo) (RepoNetworkResult, error) { + queries := []string{} + for i, repo := range repos { + queries = append(queries, fmt.Sprintf(` + repo_%03d: repository(owner: %q, name: %q) { + ...repo + parent { + ...repo + } + } + `, i, repo.RepoOwner(), repo.RepoName())) + } + + type ViewerOrRepo struct { + Login string + Repository + } + + graphqlResult := map[string]*json.RawMessage{} + result := RepoNetworkResult{} + + err := client.GraphQL(fmt.Sprintf(` + fragment repo on Repository { + id + name + owner { login } + viewerPermission + defaultBranchRef { + name + target { oid } + } + isPrivate + } + query { + viewer { login } + %s + } + `, strings.Join(queries, "")), nil, &graphqlResult) + graphqlError, isGraphQLError := err.(*GraphQLErrorResponse) + if isGraphQLError { + // If the only errors are that certain repositories are not found, + // continue processing this response instead of returning an error + tolerated := true + for _, ge := range graphqlError.Errors { + if ge.Type != "NOT_FOUND" { + tolerated = false + } + } + if tolerated { + err = nil + } + } + if err != nil { + return result, err + } + + keys := []string{} + for key := range graphqlResult { + keys = append(keys, key) + } + // sort keys to ensure `repo_{N}` entries are processed in order + sort.Sort(sort.StringSlice(keys)) + + for _, name := range keys { + jsonMessage := graphqlResult[name] + if name == "viewer" { + viewerResult := struct { + Login string + }{} + decoder := json.NewDecoder(bytes.NewReader([]byte(*jsonMessage))) + if err := decoder.Decode(&viewerResult); err != nil { + return result, err + } + result.ViewerLogin = viewerResult.Login + } else if strings.HasPrefix(name, "repo_") { + if jsonMessage == nil { + result.Repositories = append(result.Repositories, nil) + continue + } + repo := Repository{} + decoder := json.NewDecoder(bytes.NewReader([]byte(*jsonMessage))) + if err := decoder.Decode(&repo); err != nil { + return result, err + } + result.Repositories = append(result.Repositories, &repo) + } else { + return result, fmt.Errorf("unknown GraphQL result key %q", name) + } + } + return result, nil +} diff --git a/command/pr_create.go b/command/pr_create.go index a72dff7ff..48030bc20 100644 --- a/command/pr_create.go +++ b/command/pr_create.go @@ -3,6 +3,8 @@ package command import ( "fmt" "net/url" + "sort" + "strings" "github.com/github/gh-cli/api" "github.com/github/gh-cli/context" @@ -15,42 +17,62 @@ import ( func prCreate(cmd *cobra.Command, _ []string) error { ctx := contextForCommand(cmd) - - ucc, err := git.UncommittedChangeCount() + remotes, err := ctx.Remotes() if err != nil { return err } - if ucc > 0 { + + client, err := apiClientForContext(ctx) + if err != nil { + return errors.Wrap(err, "could not initialize API client") + } + + baseRepoOverride, _ := cmd.Flags().GetString("repo") + repoContext, err := resolveRemotesToRepos(remotes, client, baseRepoOverride) + if err != nil { + return err + } + + baseRepo, err := repoContext.BaseRepo() + if err != nil { + return errors.Wrap(err, "could not determine the base repository") + } + + headBranch, err := ctx.Branch() + if err != nil { + return errors.Wrap(err, "could not determine the current branch") + } + + baseBranch, err := cmd.Flags().GetString("base") + if err != nil { + return err + } + if baseBranch == "" { + baseBranch = baseRepo.DefaultBranchRef.Name + } + + headRepo, err := repoContext.HeadRepo() + if err != nil { + // TODO: auto-fork repository and add new git remote + return errors.Wrap(err, "could not determine the head repository") + } + + if headBranch == baseBranch && isSameRepo(baseRepo, headRepo) { + return fmt.Errorf("must be on a branch named differently than %q", baseBranch) + } + + fmt.Fprintf(colorableErr(cmd), "\nCreating pull request for %s into %s in %s/%s\n\n", utils.Cyan(headBranch), utils.Cyan(baseBranch), baseRepo.RepoOwner(), baseRepo.RepoName()) + + headRemote, err := repoContext.RemoteForRepo(headRepo) + if err != nil { + return errors.Wrap(err, "") + } + + if ucc, err := git.UncommittedChangeCount(); err == nil && ucc > 0 { fmt.Fprintf(cmd.ErrOrStderr(), "Warning: %s\n", utils.Pluralize(ucc, "uncommitted change")) } - - repo, err := ctx.BaseRepo() - if err != nil { - return errors.Wrap(err, "could not determine GitHub repo") - } - - head, err := ctx.Branch() - if err != nil { - return errors.Wrap(err, "could not determine current branch") - } - - remote, err := guessRemote(ctx) - if err != nil { - return err - } - - target, err := cmd.Flags().GetString("base") - if err != nil { - return err - } - if target == "" { - // TODO use default branch - target = "master" - } - - fmt.Fprintf(colorableErr(cmd), "\nCreating pull request for %s into %s in %s/%s\n\n", utils.Cyan(head), utils.Cyan(target), repo.RepoOwner(), repo.RepoName()) - - if err = git.Push(remote, fmt.Sprintf("HEAD:%s", head)); err != nil { + // TODO: respect existing upstream configuration of the current branch + if err = git.Push(headRemote.Name, fmt.Sprintf("HEAD:%s", headBranch)); err != nil { return err } @@ -59,7 +81,7 @@ func prCreate(cmd *cobra.Command, _ []string) error { return errors.Wrap(err, "could not parse web") } if isWeb { - openURL := fmt.Sprintf(`https://github.com/%s/%s/pull/%s`, repo.RepoOwner(), repo.RepoName(), head) + openURL := fmt.Sprintf(`https://github.com/%s/%s/pull/%s`, baseRepo.RepoOwner(), baseRepo.RepoName(), headBranch) fmt.Fprintf(cmd.ErrOrStderr(), "Opening %s in your browser.\n", openURL) return utils.OpenInBrowser(openURL) } @@ -104,20 +126,6 @@ func prCreate(cmd *cobra.Command, _ []string) error { } } - base, err := cmd.Flags().GetString("base") - if err != nil { - return errors.Wrap(err, "could not parse base") - } - if base == "" { - // TODO: use default branch for the repo - base = "master" - } - - client, err := apiClientForContext(ctx) - if err != nil { - return errors.Wrap(err, "could not initialize api client") - } - isDraft, err := cmd.Flags().GetBool("draft") if err != nil { return errors.Wrap(err, "could not parse draft") @@ -128,10 +136,11 @@ func prCreate(cmd *cobra.Command, _ []string) error { "title": title, "body": body, "draft": isDraft, - "baseRefName": base, - "headRefName": head, + "baseRefName": baseBranch, + "headRefName": headBranch, } + repo := api.Repository{} pr, err := api.CreatePullRequest(client, repo, params) if err != nil { return errors.Wrap(err, "failed to create pull request") @@ -141,10 +150,10 @@ func prCreate(cmd *cobra.Command, _ []string) error { } else if action == PreviewAction { openURL := fmt.Sprintf( "https://github.com/%s/%s/compare/%s...%s?expand=1&title=%s&body=%s", - repo.RepoOwner(), - repo.RepoName(), - target, - head, + baseRepo.RepoOwner(), + baseRepo.RepoName(), + baseBranch, + headBranch, url.QueryEscape(title), url.QueryEscape(body), ) @@ -163,22 +172,6 @@ func prCreate(cmd *cobra.Command, _ []string) error { } -func guessRemote(ctx context.Context) (string, error) { - remotes, err := ctx.Remotes() - if err != nil { - return "", errors.Wrap(err, "could not read git remotes") - } - - // TODO: consolidate logic with fsContext.BaseRepo - // TODO: check if the GH repo that the remote points to is writeable - remote, err := remotes.FindByName("upstream", "github", "origin", "*") - if err != nil { - return "", errors.Wrap(err, "could not determine suitable remote") - } - - return remote.Name, nil -} - var prCreateCmd = &cobra.Command{ Use: "create", Short: "Create a pull request", @@ -196,3 +189,117 @@ func init() { "The branch into which you want your code merged") prCreateCmd.Flags().BoolP("web", "w", false, "Open the web browser to create a pull request") } + +const maxRemotesForLookup = 5 + +func resolveRemotesToRepos(remotes context.Remotes, client *api.Client, base string) (resolvedRemotes, error) { + sort.Stable(remotes) + lenRemotesForLookup := len(remotes) + if lenRemotesForLookup > maxRemotesForLookup { + lenRemotesForLookup = maxRemotesForLookup + } + + hasBaseOverride := base != "" + baseOverride := repoFromFullName(base) + foundBaseOverride := false + repos := []api.Repo{} + for _, r := range remotes[:lenRemotesForLookup] { + repos = append(repos, r) + if isSameRepo(r, baseOverride) { + foundBaseOverride = true + } + } + if hasBaseOverride && !foundBaseOverride { + repos = append(repos, baseOverride) + } + + result := resolvedRemotes{remotes: remotes} + if hasBaseOverride { + result.baseOverride = baseOverride + } + networkResult, err := api.RepoNetwork(client, repos) + if err != nil { + return result, err + } + result.network = networkResult + return result, nil +} + +type resolvedRemotes struct { + baseOverride api.Repo + remotes context.Remotes + network api.RepoNetworkResult +} + +// BaseRepo is the first found repository in the "upstream", "github", "origin" +// git remote order, resolved to the parent repo if the git remote points to a fork +func (r resolvedRemotes) BaseRepo() (*api.Repository, error) { + if r.baseOverride != nil { + for _, repo := range r.network.Repositories { + if repo != nil && isSameRepo(repo, r.baseOverride) { + return repo, nil + } + } + return nil, fmt.Errorf("failed looking up information about the '%s/%s' repository", + r.baseOverride.RepoOwner(), r.baseOverride.RepoName()) + } + + for _, repo := range r.network.Repositories { + if repo == nil { + continue + } + if repo.IsFork() { + return repo.Parent, nil + } + return repo, nil + } + + return nil, errors.New("not found") +} + +// HeadRepo is the first found repository that has push access +func (r resolvedRemotes) HeadRepo() (*api.Repository, error) { + for _, repo := range r.network.Repositories { + if repo != nil && repo.ViewerCanPush() { + return repo.Parent, nil + } + } + return nil, errors.New("none of the repositories have push access") +} + +// RemoteForRepo finds the git remote that points to a repository +func (r resolvedRemotes) RemoteForRepo(repo api.Repo) (*context.Remote, error) { + for i, remote := range r.remotes { + if isSameRepo(remote, repo) || + // FIXME: express better that this is because of repo renames + (r.network.Repositories[i] != nil && isSameRepo(r.network.Repositories[i], repo)) { + return remote, nil + } + } + return nil, errors.New("not found") +} + +type ghRepo struct { + owner string + name string +} + +func repoFromFullName(nwo string) (r ghRepo) { + parts := strings.SplitN(nwo, "/", 2) + if len(parts) == 2 { + r.owner, r.name = parts[0], parts[1] + } + return +} + +func (r ghRepo) RepoOwner() string { + return r.owner +} +func (r ghRepo) RepoName() string { + return r.name +} + +func isSameRepo(a, b api.Repo) bool { + return strings.EqualFold(a.RepoOwner(), b.RepoOwner()) && + strings.EqualFold(a.RepoName(), b.RepoName()) +} diff --git a/context/remote.go b/context/remote.go index a660a85cb..3a4e7f772 100644 --- a/context/remote.go +++ b/context/remote.go @@ -35,6 +35,26 @@ func (r Remotes) FindByRepo(owner, name string) (*Remote, error) { return nil, fmt.Errorf("no matching remote found") } +func remoteNameSortScore(name string) int { + switch strings.ToLower(name) { + case "upstream": + return 3 + case "github": + return 2 + case "origin": + return 1 + default: + return 0 + } +} + +// https://golang.org/pkg/sort/#Interface +func (r Remotes) Len() int { return len(r) } +func (r Remotes) Swap(i, j int) { r[i], r[j] = r[j], r[i] } +func (r Remotes) Less(i, j int) bool { + return remoteNameSortScore(r[i].Name) > remoteNameSortScore(r[j].Name) +} + // Remote represents a git remote mapped to a GitHub repository type Remote struct { *git.Remote