diff --git a/api/pull_request_test.go b/api/pull_request_test.go index 82386108d..609a27e7c 100644 --- a/api/pull_request_test.go +++ b/api/pull_request_test.go @@ -22,7 +22,9 @@ func TestPullRequest_ChecksStatus(t *testing.T) { { "status": "COMPLETED", "conclusion": "FAILURE" }, { "status": "COMPLETED", - "conclusion": "ACTION_REQUIRED" } + "conclusion": "ACTION_REQUIRED" }, + { "status": "COMPLETED", + "conclusion": "STALE" } ] } } @@ -32,8 +34,8 @@ func TestPullRequest_ChecksStatus(t *testing.T) { eq(t, err, nil) checks := pr.ChecksStatus() - eq(t, checks.Total, 7) - eq(t, checks.Pending, 2) + eq(t, checks.Total, 8) + eq(t, checks.Pending, 3) eq(t, checks.Failing, 3) eq(t, checks.Passing, 2) } diff --git a/api/queries_pr.go b/api/queries_pr.go index b2e5f2f77..0f11f4392 100644 --- a/api/queries_pr.go +++ b/api/queries_pr.go @@ -120,7 +120,7 @@ func (pr *PullRequest) ChecksStatus() (summary PullRequestChecksStatus) { summary.Passing++ case "ERROR", "FAILURE", "CANCELLED", "TIMED_OUT", "ACTION_REQUIRED": summary.Failing++ - case "EXPECTED", "REQUESTED", "QUEUED", "PENDING", "IN_PROGRESS": + case "EXPECTED", "REQUESTED", "QUEUED", "PENDING", "IN_PROGRESS", "STALE": summary.Pending++ default: panic(fmt.Errorf("unsupported status: %q", state)) diff --git a/api/queries_repo.go b/api/queries_repo.go index e3d33b98a..5a5061eda 100644 --- a/api/queries_repo.go +++ b/api/queries_repo.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/base64" "encoding/json" + "errors" "fmt" "sort" "strings" @@ -224,6 +225,49 @@ func ForkRepo(client *Client, repo ghrepo.Interface) (*Repository, error) { }, nil } +// RepoFindFork finds a fork of repo affiliated with the viewer +func RepoFindFork(client *Client, repo ghrepo.Interface) (*Repository, error) { + result := struct { + Repository struct { + Forks struct { + Nodes []Repository + } + } + }{} + + variables := map[string]interface{}{ + "owner": repo.RepoOwner(), + "repo": repo.RepoName(), + } + + if err := client.GraphQL(` + query($owner: String!, $repo: String!) { + repository(owner: $owner, name: $repo) { + forks(first: 1, affiliations: [OWNER, COLLABORATOR]) { + nodes { + id + name + owner { login } + url + viewerPermission + } + } + } + } + `, variables, &result); err != nil { + return nil, err + } + + forks := result.Repository.Forks.Nodes + // we check ViewerCanPush, even though we expect it to always be true per + // `affiliations` condition, to guard against versions of GitHub with a + // faulty `affiliations` implementation + if len(forks) > 0 && forks[0].ViewerCanPush() { + return &forks[0], nil + } + return nil, &NotFoundError{errors.New("no fork found")} +} + // RepoCreateInput represents input parameters for RepoCreate type RepoCreateInput struct { Name string `json:"name"` diff --git a/command/pr.go b/command/pr.go index fcf186eb6..7c8799248 100644 --- a/command/pr.go +++ b/command/pr.go @@ -418,7 +418,9 @@ func printPrs(w io.Writer, totalCount int, prs ...api.PullRequest) { reviews := pr.ReviewStatus() if pr.State == "OPEN" { - if checks.Total > 0 || reviews.ChangesRequested || reviews.Approved { + reviewStatus := reviews.ChangesRequested || reviews.Approved || reviews.ReviewRequired + if checks.Total > 0 || reviewStatus { + // show checks & reviews on their own line fmt.Fprintf(w, "\n ") } @@ -426,24 +428,29 @@ func printPrs(w io.Writer, totalCount int, prs ...api.PullRequest) { var summary string if checks.Failing > 0 { if checks.Failing == checks.Total { - summary = utils.Red("All checks failing") + summary = utils.Red("× All checks failing") } else { - summary = utils.Red(fmt.Sprintf("%d/%d checks failing", checks.Failing, checks.Total)) + summary = utils.Red(fmt.Sprintf("× %d/%d checks failing", checks.Failing, checks.Total)) } } else if checks.Pending > 0 { - summary = utils.Yellow("Checks pending") + summary = utils.Yellow("- Checks pending") } else if checks.Passing == checks.Total { - summary = utils.Green("Checks passing") + summary = utils.Green("✓ Checks passing") } - fmt.Fprintf(w, " - %s", summary) + fmt.Fprint(w, summary) + } + + if checks.Total > 0 && reviewStatus { + // add padding between checks & reviews + fmt.Fprint(w, " ") } if reviews.ChangesRequested { - fmt.Fprintf(w, " - %s", utils.Red("Changes requested")) + fmt.Fprint(w, utils.Red("+ Changes requested")) } else if reviews.ReviewRequired { - fmt.Fprintf(w, " - %s", utils.Yellow("Review required")) + fmt.Fprint(w, utils.Yellow("- Review required")) } else if reviews.Approved { - fmt.Fprintf(w, " - %s", utils.Green("Approved")) + fmt.Fprint(w, utils.Green("✓ Approved")) } } else { s := strings.Title(strings.ToLower(pr.State)) diff --git a/command/pr_checkout.go b/command/pr_checkout.go index f23d0a083..5db8fbd07 100644 --- a/command/pr_checkout.go +++ b/command/pr_checkout.go @@ -67,7 +67,7 @@ func prCheckout(cmd *cobra.Command, args []string) error { cmdQueue = append(cmdQueue, []string{"git", "fetch", headRemote.Name, refSpec}) // local branch already exists - if git.VerifyRef("refs/heads/" + newBranchName) { + if _, err := git.ShowRefs("refs/heads/" + newBranchName); err == nil { cmdQueue = append(cmdQueue, []string{"git", "checkout", newBranchName}) cmdQueue = append(cmdQueue, []string{"git", "merge", "--ff-only", fmt.Sprintf("refs/remotes/%s", remoteBranch)}) } else { diff --git a/command/pr_checkout_test.go b/command/pr_checkout_test.go index 9aca71901..a66163ed1 100644 --- a/command/pr_checkout_test.go +++ b/command/pr_checkout_test.go @@ -46,7 +46,7 @@ func TestPRCheckout_sameRepo(t *testing.T) { ranCommands := [][]string{} restoreCmd := run.SetPrepareCmd(func(cmd *exec.Cmd) run.Runnable { switch strings.Join(cmd.Args, " ") { - case "git show-ref --verify --quiet refs/heads/feature": + case "git show-ref --verify -- refs/heads/feature": return &errorStub{"exit status: 1"} default: ranCommands = append(ranCommands, cmd.Args) @@ -98,7 +98,7 @@ func TestPRCheckout_urlArg(t *testing.T) { ranCommands := [][]string{} restoreCmd := run.SetPrepareCmd(func(cmd *exec.Cmd) run.Runnable { switch strings.Join(cmd.Args, " ") { - case "git show-ref --verify --quiet refs/heads/feature": + case "git show-ref --verify -- refs/heads/feature": return &errorStub{"exit status: 1"} default: ranCommands = append(ranCommands, cmd.Args) @@ -147,7 +147,7 @@ func TestPRCheckout_urlArg_differentBase(t *testing.T) { ranCommands := [][]string{} restoreCmd := run.SetPrepareCmd(func(cmd *exec.Cmd) run.Runnable { switch strings.Join(cmd.Args, " ") { - case "git show-ref --verify --quiet refs/heads/feature": + case "git show-ref --verify -- refs/heads/feature": return &errorStub{"exit status: 1"} default: ranCommands = append(ranCommands, cmd.Args) @@ -210,7 +210,7 @@ func TestPRCheckout_branchArg(t *testing.T) { ranCommands := [][]string{} restoreCmd := run.SetPrepareCmd(func(cmd *exec.Cmd) run.Runnable { switch strings.Join(cmd.Args, " ") { - case "git show-ref --verify --quiet refs/heads/feature": + case "git show-ref --verify -- refs/heads/feature": return &errorStub{"exit status: 1"} default: ranCommands = append(ranCommands, cmd.Args) @@ -260,7 +260,7 @@ func TestPRCheckout_existingBranch(t *testing.T) { ranCommands := [][]string{} restoreCmd := run.SetPrepareCmd(func(cmd *exec.Cmd) run.Runnable { switch strings.Join(cmd.Args, " ") { - case "git show-ref --verify --quiet refs/heads/feature": + case "git show-ref --verify -- refs/heads/feature": return &test.OutputStub{} default: ranCommands = append(ranCommands, cmd.Args) @@ -313,7 +313,7 @@ func TestPRCheckout_differentRepo_remoteExists(t *testing.T) { ranCommands := [][]string{} restoreCmd := run.SetPrepareCmd(func(cmd *exec.Cmd) run.Runnable { switch strings.Join(cmd.Args, " ") { - case "git show-ref --verify --quiet refs/heads/feature": + case "git show-ref --verify -- refs/heads/feature": return &errorStub{"exit status: 1"} default: ranCommands = append(ranCommands, cmd.Args) diff --git a/command/pr_create.go b/command/pr_create.go index 2be922987..238057bf8 100644 --- a/command/pr_create.go +++ b/command/pr_create.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net/url" + "strings" "time" "github.com/cli/cli/api" @@ -75,7 +76,27 @@ func prCreate(cmd *cobra.Command, _ []string) error { if err != nil { return fmt.Errorf("could not determine the current branch: %w", err) } - headRepo, headRepoErr := repoContext.HeadRepo() + + var headRepo ghrepo.Interface + var headRemote *context.Remote + + // determine whether the head branch is already pushed to a remote + headBranchPushedTo := determineTrackingBranch(remotes, headBranch) + if headBranchPushedTo != nil { + for _, r := range remotes { + if r.Name != headBranchPushedTo.RemoteName { + continue + } + headRepo = r + headRemote = r + break + } + } + + // otherwise, determine the head repository with info obtained from the API + if headRepo == nil { + headRepo, _ = repoContext.HeadRepo() + } baseBranch, err := cmd.Flags().GetString("base") if err != nil { @@ -193,8 +214,9 @@ func prCreate(cmd *cobra.Command, _ []string) error { } didForkRepo := false - var headRemote *context.Remote - if headRepoErr != nil { + // if a head repository could not be determined so far, automatically create + // one by forking the base repository + if headRepo == nil { if baseRepo.IsPrivate { return fmt.Errorf("cannot fork private repository '%s'", ghrepo.FullName(baseRepo)) } @@ -203,11 +225,25 @@ func prCreate(cmd *cobra.Command, _ []string) error { return fmt.Errorf("error forking repo: %w", err) } didForkRepo = true + } + + headBranchLabel := headBranch + if !ghrepo.IsSame(baseRepo, headRepo) { + headBranchLabel = fmt.Sprintf("%s:%s", headRepo.RepoOwner(), headBranch) + } + + // There are two cases when an existing remote for the head repo will be + // missing: + // 1. the head repo was just created by auto-forking; + // 2. an existing fork was discovered by quering the API. + // + // In either case, we want to add the head repo as a new git remote so we + // can push to it. + if err != nil { // TODO: support non-HTTPS git remote URLs - baseRepoURL := fmt.Sprintf("https://github.com/%s.git", ghrepo.FullName(baseRepo)) headRepoURL := fmt.Sprintf("https://github.com/%s.git", ghrepo.FullName(headRepo)) - // TODO: figure out what to name the new git remote - gitRemote, err := git.AddRemote("fork", baseRepoURL, headRepoURL) + // TODO: prevent clashes with another remote of a same name + gitRemote, err := git.AddRemote("fork", headRepoURL) if err != nil { return fmt.Errorf("error adding remote: %w", err) } @@ -218,34 +254,31 @@ func prCreate(cmd *cobra.Command, _ []string) error { } } - headBranchLabel := headBranch - if !ghrepo.IsSame(baseRepo, headRepo) { - headBranchLabel = fmt.Sprintf("%s:%s", headRepo.RepoOwner(), headBranch) - } - - if headRemote == nil { - headRemote, err = repoContext.RemoteForRepo(headRepo) - if err != nil { - return fmt.Errorf("git remote not found for head repository: %w", err) - } - } - - pushTries := 0 - maxPushTries := 3 - for { - // TODO: respect existing upstream configuration of the current branch - if err := git.Push(headRemote.Name, fmt.Sprintf("HEAD:%s", headBranch)); err != nil { - if didForkRepo && pushTries < maxPushTries { - pushTries++ - // first wait 2 seconds after forking, then 4s, then 6s - waitSeconds := 2 * pushTries - fmt.Fprintf(cmd.ErrOrStderr(), "waiting %s before retrying...\n", utils.Pluralize(waitSeconds, "second")) - time.Sleep(time.Duration(waitSeconds) * time.Second) - continue + // automatically push the branch if it hasn't been pushed anywhere yet + if headBranchPushedTo == nil { + if headRemote == nil { + headRemote, err = repoContext.RemoteForRepo(headRepo) + if err != nil { + return fmt.Errorf("git remote not found for head repository: %w", err) } - return err } - break + + pushTries := 0 + maxPushTries := 3 + for { + if err := git.Push(headRemote.Name, fmt.Sprintf("HEAD:%s", headBranch)); err != nil { + if didForkRepo && pushTries < maxPushTries { + pushTries++ + // first wait 2 seconds after forking, then 4s, then 6s + waitSeconds := 2 * pushTries + fmt.Fprintf(cmd.ErrOrStderr(), "waiting %s before retrying...\n", utils.Pluralize(waitSeconds, "second")) + time.Sleep(time.Duration(waitSeconds) * time.Second) + continue + } + return err + } + break + } } if action == SubmitAction { @@ -275,6 +308,47 @@ func prCreate(cmd *cobra.Command, _ []string) error { return nil } +func determineTrackingBranch(remotes context.Remotes, headBranch string) *git.TrackingRef { + refsForLookup := []string{"HEAD"} + var trackingRefs []git.TrackingRef + + headBranchConfig := git.ReadBranchConfig(headBranch) + if headBranchConfig.RemoteName != "" { + tr := git.TrackingRef{ + RemoteName: headBranchConfig.RemoteName, + BranchName: strings.TrimPrefix(headBranchConfig.MergeRef, "refs/heads/"), + } + trackingRefs = append(trackingRefs, tr) + refsForLookup = append(refsForLookup, tr.String()) + } + + for _, remote := range remotes { + tr := git.TrackingRef{ + RemoteName: remote.Name, + BranchName: headBranch, + } + trackingRefs = append(trackingRefs, tr) + refsForLookup = append(refsForLookup, tr.String()) + } + + resolvedRefs, _ := git.ShowRefs(refsForLookup...) + if len(resolvedRefs) > 1 { + for _, r := range resolvedRefs[1:] { + if r.Hash != resolvedRefs[0].Hash { + continue + } + for _, tr := range trackingRefs { + if tr.String() != r.Name { + continue + } + return &tr + } + } + } + + return nil +} + func generateCompareURL(r ghrepo.Interface, base, head, title, body string) string { u := fmt.Sprintf( "https://github.com/%s/compare/%s...%s?expand=1", diff --git a/command/pr_create_test.go b/command/pr_create_test.go index 76a7c8803..2791a083d 100644 --- a/command/pr_create_test.go +++ b/command/pr_create_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/cli/cli/context" + "github.com/cli/cli/git" "github.com/cli/cli/test" ) @@ -15,6 +16,10 @@ func TestPRCreate(t *testing.T) { initBlankContext("OWNER/REPO", "feature") http := initFakeHTTP() http.StubRepoResponse("OWNER", "REPO") + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { "forks": { "nodes": [ + ] } } } } + `)) http.StubResponse(200, bytes.NewBufferString(` { "data": { "repository": { "pullRequests": { "nodes" : [ ] } } } } @@ -28,6 +33,8 @@ func TestPRCreate(t *testing.T) { cs, cmdTeardown := test.InitCmdStubber() defer cmdTeardown() + cs.Stub("") // git config --get-regexp (determineTrackingBranch) + cs.Stub("") // git show-ref --verify (determineTrackingBranch) cs.Stub("") // git status cs.Stub("1234567890,commit 0\n2345678901,commit 1") // git log cs.Stub("") // git push @@ -35,7 +42,7 @@ func TestPRCreate(t *testing.T) { output, err := RunCommand(prCreateCmd, `pr create -t "my title" -b "my body"`) eq(t, err, nil) - bodyBytes, _ := ioutil.ReadAll(http.Requests[2].Body) + bodyBytes, _ := ioutil.ReadAll(http.Requests[3].Body) reqBody := struct { Variables struct { Input struct { @@ -62,6 +69,10 @@ func TestPRCreate_alreadyExists(t *testing.T) { initBlankContext("OWNER/REPO", "feature") http := initFakeHTTP() http.StubRepoResponse("OWNER", "REPO") + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { "forks": { "nodes": [ + ] } } } } + `)) http.StubResponse(200, bytes.NewBufferString(` { "data": { "repository": { "pullRequests": { "nodes": [ { "url": "https://github.com/OWNER/REPO/pull/123", @@ -73,6 +84,8 @@ func TestPRCreate_alreadyExists(t *testing.T) { cs, cmdTeardown := test.InitCmdStubber() defer cmdTeardown() + cs.Stub("") // git config --get-regexp (determineTrackingBranch) + cs.Stub("") // git show-ref --verify (determineTrackingBranch) cs.Stub("") // git status cs.Stub("1234567890,commit 0\n2345678901,commit 1") // git log @@ -89,6 +102,10 @@ func TestPRCreate_alreadyExistsDifferentBase(t *testing.T) { initBlankContext("OWNER/REPO", "feature") http := initFakeHTTP() http.StubRepoResponse("OWNER", "REPO") + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { "forks": { "nodes": [ + ] } } } } + `)) http.StubResponse(200, bytes.NewBufferString(` { "data": { "repository": { "pullRequests": { "nodes": [ { "url": "https://github.com/OWNER/REPO/pull/123", @@ -101,6 +118,8 @@ func TestPRCreate_alreadyExistsDifferentBase(t *testing.T) { cs, cmdTeardown := test.InitCmdStubber() defer cmdTeardown() + cs.Stub("") // git config --get-regexp (determineTrackingBranch) + cs.Stub("") // git show-ref --verify (determineTrackingBranch) cs.Stub("") // git status cs.Stub("1234567890,commit 0\n2345678901,commit 1") // git log cs.Stub("") // git rev-parse @@ -115,10 +134,16 @@ func TestPRCreate_web(t *testing.T) { initBlankContext("OWNER/REPO", "feature") http := initFakeHTTP() http.StubRepoResponse("OWNER", "REPO") + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { "forks": { "nodes": [ + ] } } } } + `)) cs, cmdTeardown := test.InitCmdStubber() defer cmdTeardown() + cs.Stub("") // git config --get-regexp (determineTrackingBranch) + cs.Stub("") // git show-ref --verify (determineTrackingBranch) cs.Stub("") // git status cs.Stub("1234567890,commit 0\n2345678901,commit 1") // git log cs.Stub("") // git push @@ -130,9 +155,9 @@ func TestPRCreate_web(t *testing.T) { eq(t, output.String(), "") eq(t, output.Stderr(), "Opening github.com/OWNER/REPO/compare/master...feature in your browser.\n") - eq(t, len(cs.Calls), 4) - eq(t, strings.Join(cs.Calls[2].Args, " "), "git push --set-upstream origin HEAD:feature") - browserCall := cs.Calls[3].Args + eq(t, len(cs.Calls), 6) + eq(t, strings.Join(cs.Calls[4].Args, " "), "git push --set-upstream origin HEAD:feature") + browserCall := cs.Calls[5].Args eq(t, browserCall[len(browserCall)-1], "https://github.com/OWNER/REPO/compare/master...feature?expand=1") } @@ -141,6 +166,10 @@ func TestPRCreate_ReportsUncommittedChanges(t *testing.T) { http := initFakeHTTP() http.StubRepoResponse("OWNER", "REPO") + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { "forks": { "nodes": [ + ] } } } } + `)) http.StubResponse(200, bytes.NewBufferString(` { "data": { "repository": { "pullRequests": { "nodes" : [ ] } } } } @@ -154,6 +183,8 @@ func TestPRCreate_ReportsUncommittedChanges(t *testing.T) { cs, cmdTeardown := test.InitCmdStubber() defer cmdTeardown() + cs.Stub("") // git config --get-regexp (determineTrackingBranch) + cs.Stub("") // git show-ref --verify (determineTrackingBranch) cs.Stub(" M git/git.go") // git status cs.Stub("1234567890,commit 0\n2345678901,commit 1") // git log cs.Stub("") // git push @@ -224,6 +255,8 @@ func TestPRCreate_cross_repo_same_branch(t *testing.T) { cs, cmdTeardown := test.InitCmdStubber() defer cmdTeardown() + cs.Stub("") // git config --get-regexp (determineTrackingBranch) + cs.Stub("") // git show-ref --verify (determineTrackingBranch) cs.Stub("") // git status cs.Stub("1234567890,commit 0\n2345678901,commit 1") // git log cs.Stub("") // git push @@ -260,6 +293,10 @@ func TestPRCreate_survey_defaults_multicommit(t *testing.T) { initBlankContext("OWNER/REPO", "cool_bug-fixes") http := initFakeHTTP() http.StubRepoResponse("OWNER", "REPO") + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { "forks": { "nodes": [ + ] } } } } + `)) http.StubResponse(200, bytes.NewBufferString(` { "data": { "repository": { "pullRequests": { "nodes" : [ ] } } } } @@ -273,6 +310,8 @@ func TestPRCreate_survey_defaults_multicommit(t *testing.T) { cs, cmdTeardown := test.InitCmdStubber() defer cmdTeardown() + cs.Stub("") // git config --get-regexp (determineTrackingBranch) + cs.Stub("") // git show-ref --verify (determineTrackingBranch) cs.Stub("") // git status cs.Stub("1234567890,commit 0\n2345678901,commit 1") // git log cs.Stub("") // git rev-parse @@ -301,7 +340,7 @@ func TestPRCreate_survey_defaults_multicommit(t *testing.T) { output, err := RunCommand(prCreateCmd, `pr create`) eq(t, err, nil) - bodyBytes, _ := ioutil.ReadAll(http.Requests[2].Body) + bodyBytes, _ := ioutil.ReadAll(http.Requests[3].Body) reqBody := struct { Variables struct { Input struct { @@ -330,6 +369,10 @@ func TestPRCreate_survey_defaults_monocommit(t *testing.T) { initBlankContext("OWNER/REPO", "feature") http := initFakeHTTP() http.StubRepoResponse("OWNER", "REPO") + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { "forks": { "nodes": [ + ] } } } } + `)) http.StubResponse(200, bytes.NewBufferString(` { "data": { "repository": { "pullRequests": { "nodes" : [ ] } } } } @@ -343,6 +386,8 @@ func TestPRCreate_survey_defaults_monocommit(t *testing.T) { cs, cmdTeardown := test.InitCmdStubber() defer cmdTeardown() + cs.Stub("") // git config --get-regexp (determineTrackingBranch) + cs.Stub("") // git show-ref --verify (determineTrackingBranch) cs.Stub("") // git status cs.Stub("1234567890,the sky above the port") // git log cs.Stub("was the color of a television, turned to a dead channel") // git show @@ -372,7 +417,7 @@ func TestPRCreate_survey_defaults_monocommit(t *testing.T) { output, err := RunCommand(prCreateCmd, `pr create`) eq(t, err, nil) - bodyBytes, _ := ioutil.ReadAll(http.Requests[2].Body) + bodyBytes, _ := ioutil.ReadAll(http.Requests[3].Body) reqBody := struct { Variables struct { Input struct { @@ -401,6 +446,10 @@ func TestPRCreate_survey_autofill(t *testing.T) { initBlankContext("OWNER/REPO", "feature") http := initFakeHTTP() http.StubRepoResponse("OWNER", "REPO") + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { "forks": { "nodes": [ + ] } } } } + `)) http.StubResponse(200, bytes.NewBufferString(` { "data": { "repository": { "pullRequests": { "nodes" : [ ] } } } } @@ -414,6 +463,8 @@ func TestPRCreate_survey_autofill(t *testing.T) { cs, cmdTeardown := test.InitCmdStubber() defer cmdTeardown() + cs.Stub("") // git config --get-regexp (determineTrackingBranch) + cs.Stub("") // git show-ref --verify (determineTrackingBranch) cs.Stub("") // git status cs.Stub("1234567890,the sky above the port") // git log cs.Stub("was the color of a television, turned to a dead channel") // git show @@ -424,7 +475,7 @@ func TestPRCreate_survey_autofill(t *testing.T) { output, err := RunCommand(prCreateCmd, `pr create -f`) eq(t, err, nil) - bodyBytes, _ := ioutil.ReadAll(http.Requests[2].Body) + bodyBytes, _ := ioutil.ReadAll(http.Requests[3].Body) reqBody := struct { Variables struct { Input struct { @@ -457,6 +508,8 @@ func TestPRCreate_defaults_error_autofill(t *testing.T) { cs, cmdTeardown := test.InitCmdStubber() defer cmdTeardown() + cs.Stub("") // git config --get-regexp (determineTrackingBranch) + cs.Stub("") // git show-ref --verify (determineTrackingBranch) cs.Stub("") // git status cs.Stub("") // git log @@ -473,6 +526,8 @@ func TestPRCreate_defaults_error_web(t *testing.T) { cs, cmdTeardown := test.InitCmdStubber() defer cmdTeardown() + cs.Stub("") // git config --get-regexp (determineTrackingBranch) + cs.Stub("") // git show-ref --verify (determineTrackingBranch) cs.Stub("") // git status cs.Stub("") // git log @@ -485,6 +540,10 @@ func TestPRCreate_defaults_error_interactive(t *testing.T) { initBlankContext("OWNER/REPO", "feature") http := initFakeHTTP() http.StubRepoResponse("OWNER", "REPO") + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { "forks": { "nodes": [ + ] } } } } + `)) http.StubResponse(200, bytes.NewBufferString(` { "data": { "createPullRequest": { "pullRequest": { "URL": "https://github.com/OWNER/REPO/pull/12" @@ -494,6 +553,8 @@ func TestPRCreate_defaults_error_interactive(t *testing.T) { cs, cmdTeardown := test.InitCmdStubber() defer cmdTeardown() + cs.Stub("") // git config --get-regexp (determineTrackingBranch) + cs.Stub("") // git show-ref --verify (determineTrackingBranch) cs.Stub("") // git status cs.Stub("") // git log cs.Stub("") // git rev-parse @@ -526,3 +587,103 @@ func TestPRCreate_defaults_error_interactive(t *testing.T) { stderr := string(output.Stderr()) eq(t, strings.Contains(stderr, "warning: could not compute title or body defaults: could not find any commits"), true) } + +func Test_determineTrackingBranch_empty(t *testing.T) { + cs, cmdTeardown := initCmdStubber() + defer cmdTeardown() + + remotes := context.Remotes{} + + cs.Stub("") // git config --get-regexp (ReadBranchConfig) + cs.Stub("deadbeef HEAD") // git show-ref --verify (ShowRefs) + + ref := determineTrackingBranch(remotes, "feature") + if ref != nil { + t.Errorf("expected nil result, got %v", ref) + } +} + +func Test_determineTrackingBranch_noMatch(t *testing.T) { + cs, cmdTeardown := initCmdStubber() + defer cmdTeardown() + + remotes := context.Remotes{ + &context.Remote{ + Remote: &git.Remote{Name: "origin"}, + Owner: "hubot", + Repo: "Spoon-Knife", + }, + &context.Remote{ + Remote: &git.Remote{Name: "upstream"}, + Owner: "octocat", + Repo: "Spoon-Knife", + }, + } + + cs.Stub("") // git config --get-regexp (ReadBranchConfig) + cs.Stub(`deadbeef HEAD +deadb00f refs/remotes/origin/feature`) // git show-ref --verify (ShowRefs) + + ref := determineTrackingBranch(remotes, "feature") + if ref != nil { + t.Errorf("expected nil result, got %v", ref) + } +} + +func Test_determineTrackingBranch_hasMatch(t *testing.T) { + cs, cmdTeardown := initCmdStubber() + defer cmdTeardown() + + remotes := context.Remotes{ + &context.Remote{ + Remote: &git.Remote{Name: "origin"}, + Owner: "hubot", + Repo: "Spoon-Knife", + }, + &context.Remote{ + Remote: &git.Remote{Name: "upstream"}, + Owner: "octocat", + Repo: "Spoon-Knife", + }, + } + + cs.Stub("") // git config --get-regexp (ReadBranchConfig) + cs.Stub(`deadbeef HEAD +deadb00f refs/remotes/origin/feature +deadbeef refs/remotes/upstream/feature`) // git show-ref --verify (ShowRefs) + + ref := determineTrackingBranch(remotes, "feature") + if ref == nil { + t.Fatal("expected result, got nil") + } + + eq(t, cs.Calls[1].Args, []string{"git", "show-ref", "--verify", "--", "HEAD", "refs/remotes/origin/feature", "refs/remotes/upstream/feature"}) + + eq(t, ref.RemoteName, "upstream") + eq(t, ref.BranchName, "feature") +} + +func Test_determineTrackingBranch_respectTrackingConfig(t *testing.T) { + cs, cmdTeardown := initCmdStubber() + defer cmdTeardown() + + remotes := context.Remotes{ + &context.Remote{ + Remote: &git.Remote{Name: "origin"}, + Owner: "hubot", + Repo: "Spoon-Knife", + }, + } + + cs.Stub(`branch.feature.remote origin +branch.feature.merge refs/heads/great-feat`) // git config --get-regexp (ReadBranchConfig) + cs.Stub(`deadbeef HEAD +deadb00f refs/remotes/origin/feature`) // git show-ref --verify (ShowRefs) + + ref := determineTrackingBranch(remotes, "feature") + if ref != nil { + t.Errorf("expected nil result, got %v", ref) + } + + eq(t, cs.Calls[1].Args, []string{"git", "show-ref", "--verify", "--", "HEAD", "refs/remotes/origin/great-feat", "refs/remotes/origin/feature"}) +} diff --git a/command/pr_test.go b/command/pr_test.go index 8a15dfe51..73d58bfb8 100644 --- a/command/pr_test.go +++ b/command/pr_test.go @@ -144,9 +144,9 @@ func TestPRStatus_reviewsAndChecks(t *testing.T) { } expected := []string{ - "- Checks passing - Changes requested", - "- Checks pending - Approved", - "- 1/3 checks failing - Review required", + "✓ Checks passing + Changes requested", + "- Checks pending ✓ Approved", + "× 1/3 checks failing - Review required", } for _, line := range expected { diff --git a/command/repo.go b/command/repo.go index 60ca2778e..959cfbaec 100644 --- a/command/repo.go +++ b/command/repo.go @@ -304,14 +304,6 @@ func repoFork(cmd *cobra.Command, args []string) error { s.FinalMSG = utils.Gray(fmt.Sprintf("- %s\n", loading)) s.Start() - authLogin, err := ctx.AuthLogin() - if err != nil { - s.Stop() - return fmt.Errorf("could not determine current username: %w", err) - } - - possibleFork := ghrepo.New(authLogin, toFork.RepoName()) - forkedRepo, err := api.ForkRepo(apiClient, toFork) if err != nil { s.Stop() @@ -324,11 +316,11 @@ func repoFork(cmd *cobra.Command, args []string) error { // returns the fork repo data even if it already exists -- with no change in status code or // anything. We thus check the created time to see if the repo is brand new or not; if it's not, // we assume the fork already existed and report an error. - created_ago := Since(forkedRepo.CreatedAt) - if created_ago > time.Minute { + createdAgo := Since(forkedRepo.CreatedAt) + if createdAgo > time.Minute { fmt.Fprintf(out, "%s %s %s\n", utils.Yellow("!"), - utils.Bold(ghrepo.FullName(possibleFork)), + utils.Bold(ghrepo.FullName(forkedRepo)), "already exists") } else { fmt.Fprintf(out, "%s Created fork %s\n", greenCheck, utils.Bold(ghrepo.FullName(forkedRepo))) @@ -339,6 +331,15 @@ func repoFork(cmd *cobra.Command, args []string) error { } if inParent { + remotes, err := ctx.Remotes() + if err != nil { + return err + } + if remote, err := remotes.FindByRepo(forkedRepo.RepoOwner(), forkedRepo.RepoName()); err == nil { + fmt.Fprintf(out, "%s Using existing remote %s\n", greenCheck, utils.Bold(remote.Name)) + return nil + } + remoteDesired := remotePref == "true" if remotePref == "prompt" { err = Confirm("Would you like to add a remote for the fork?", &remoteDesired) @@ -347,7 +348,7 @@ func repoFork(cmd *cobra.Command, args []string) error { } } if remoteDesired { - _, err := git.AddRemote("fork", forkedRepo.CloneURL, "") + _, err := git.AddRemote("fork", forkedRepo.CloneURL) if err != nil { return fmt.Errorf("failed to add remote: %w", err) } diff --git a/command/repo_test.go b/command/repo_test.go index 09291462a..1b9aa2642 100644 --- a/command/repo_test.go +++ b/command/repo_test.go @@ -18,9 +18,8 @@ import ( func TestRepoFork_already_forked(t *testing.T) { initContext = func() context.Context { ctx := context.NewBlank() - ctx.SetBaseRepo("REPO") + ctx.SetBaseRepo("OWNER/REPO") ctx.SetBranch("master") - ctx.SetAuthLogin("someone") ctx.SetRemotes(map[string]string{ "origin": "OWNER/REPO", }) @@ -41,6 +40,31 @@ func TestRepoFork_already_forked(t *testing.T) { } } +func TestRepoFork_reuseRemote(t *testing.T) { + initContext = func() context.Context { + ctx := context.NewBlank() + ctx.SetBaseRepo("OWNER/REPO") + ctx.SetBranch("master") + ctx.SetRemotes(map[string]string{ + "upstream": "OWNER/REPO", + "origin": "someone/REPO", + }) + return ctx + } + http := initFakeHTTP() + http.StubRepoResponse("OWNER", "REPO") + defer http.StubWithFixture(200, "forkResult.json")() + + output, err := RunCommand(repoForkCmd, "repo fork") + if err != nil { + t.Errorf("got unexpected error: %v", err) + } + if !strings.Contains(output.String(), "Using existing remote origin") { + t.Errorf("output did not match: %q", output) + return + } +} + func stubSince(d time.Duration) func() { originalSince := Since Since = func(t time.Time) time.Duration { diff --git a/context/context.go b/context/context.go index 23ce74d85..eb1ace0eb 100644 --- a/context/context.go +++ b/context/context.go @@ -51,7 +51,10 @@ func ResolveRemotesToRepos(remotes Remotes, client *api.Client, base string) (Re repos = append(repos, baseOverride) } - result := ResolvedRemotes{Remotes: remotes} + result := ResolvedRemotes{ + Remotes: remotes, + apiClient: client, + } if hasBaseOverride { result.BaseOverride = baseOverride } @@ -67,6 +70,7 @@ type ResolvedRemotes struct { BaseOverride ghrepo.Interface Remotes Remotes Network api.RepoNetworkResult + apiClient *api.Client } // BaseRepo is the first found repository in the "upstream", "github", "origin" @@ -95,8 +99,30 @@ func (r ResolvedRemotes) BaseRepo() (*api.Repository, error) { return nil, errors.New("not found") } -// HeadRepo is the first found repository that has push access +// HeadRepo is a fork of base repo (if any), or the first found repository that +// has push access func (r ResolvedRemotes) HeadRepo() (*api.Repository, error) { + baseRepo, err := r.BaseRepo() + if err != nil { + return nil, err + } + + // try to find a pushable fork among existing remotes + for _, repo := range r.Network.Repositories { + if repo != nil && repo.Parent != nil && repo.ViewerCanPush() && ghrepo.IsSame(repo.Parent, baseRepo) { + return repo, nil + } + } + + // a fork might still exist on GitHub, so let's query for it + var notFound *api.NotFoundError + if repo, err := api.RepoFindFork(r.apiClient, baseRepo); err == nil { + return repo, nil + } else if !errors.As(err, ¬Found) { + return nil, err + } + + // fall back to any listed repository that has push access for _, repo := range r.Network.Repositories { if repo != nil && repo.ViewerCanPush() { return repo, nil diff --git a/context/remote_test.go b/context/remote_test.go index 04ccbc56c..5c3c46bc7 100644 --- a/context/remote_test.go +++ b/context/remote_test.go @@ -1,6 +1,7 @@ package context import ( + "bytes" "errors" "net/url" "testing" @@ -61,6 +62,14 @@ func Test_translateRemotes(t *testing.T) { } func Test_resolvedRemotes_triangularSetup(t *testing.T) { + http := &api.FakeHTTP{} + apiClient := api.NewClient(api.ReplaceTripper(http)) + + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { "forks": { "nodes": [ + ] } } } } + `)) + resolved := ResolvedRemotes{ BaseOverride: nil, Remotes: Remotes{ @@ -89,6 +98,7 @@ func Test_resolvedRemotes_triangularSetup(t *testing.T) { }, }, }, + apiClient: apiClient, } baseRepo, err := resolved.BaseRepo() @@ -118,6 +128,53 @@ func Test_resolvedRemotes_triangularSetup(t *testing.T) { } } +func Test_resolvedRemotes_forkLookup(t *testing.T) { + http := &api.FakeHTTP{} + apiClient := api.NewClient(api.ReplaceTripper(http)) + + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { "forks": { "nodes": [ + { "id": "FORKID", + "url": "https://github.com/FORKOWNER/REPO", + "name": "REPO", + "owner": { "login": "FORKOWNER" }, + "viewerPermission": "WRITE" + } + ] } } } } + `)) + + resolved := ResolvedRemotes{ + BaseOverride: nil, + Remotes: Remotes{ + &Remote{ + Remote: &git.Remote{Name: "origin"}, + Owner: "OWNER", + Repo: "REPO", + }, + }, + Network: api.RepoNetworkResult{ + Repositories: []*api.Repository{ + &api.Repository{ + Name: "NEWNAME", + Owner: api.RepositoryOwner{Login: "NEWOWNER"}, + ViewerPermission: "READ", + }, + }, + }, + apiClient: apiClient, + } + + headRepo, err := resolved.HeadRepo() + if err != nil { + t.Fatalf("got %v", err) + } + eq(t, ghrepo.FullName(headRepo), "FORKOWNER/REPO") + _, err = resolved.RemoteForRepo(headRepo) + if err == nil { + t.Fatal("expected to not find a matching remote") + } +} + func Test_resolvedRemotes_clonedFork(t *testing.T) { resolved := ResolvedRemotes{ BaseOverride: nil, diff --git a/git/git.go b/git/git.go index a4fae4408..360917eab 100644 --- a/git/git.go +++ b/git/git.go @@ -13,10 +13,41 @@ import ( "github.com/cli/cli/internal/run" ) -func VerifyRef(ref string) bool { - showRef := exec.Command("git", "show-ref", "--verify", "--quiet", ref) - err := run.PrepareCmd(showRef).Run() - return err == nil +// Ref represents a git commit reference +type Ref struct { + Hash string + Name string +} + +// TrackingRef represents a ref for a remote tracking branch +type TrackingRef struct { + RemoteName string + BranchName string +} + +func (r TrackingRef) String() string { + return "refs/remotes/" + r.RemoteName + "/" + r.BranchName +} + +// ShowRefs resolves fully-qualified refs to commit hashes +func ShowRefs(ref ...string) ([]Ref, error) { + args := append([]string{"show-ref", "--verify", "--"}, ref...) + showRef := exec.Command("git", args...) + output, err := run.PrepareCmd(showRef).Output() + + var refs []Ref + for _, line := range outputLines(output) { + parts := strings.SplitN(line, " ", 2) + if len(parts) < 2 { + continue + } + refs = append(refs, Ref{ + Hash: parts[0], + Name: parts[1], + }) + } + + return refs, err } // CurrentBranch reads the checked-out branch for the git repository @@ -162,7 +193,7 @@ func ReadBranchConfig(branch string) (cfg BranchConfig) { continue } cfg.RemoteURL = u - } else { + } else if !isFilesystemPath(parts[1]) { cfg.RemoteName = parts[1] } case "merge": @@ -172,6 +203,10 @@ func ReadBranchConfig(branch string) (cfg BranchConfig) { return } +func isFilesystemPath(p string) bool { + return p == "." || strings.HasPrefix(p, "./") || strings.HasPrefix(p, "/") +} + // ToplevelDir returns the top-level directory path of the current repository func ToplevelDir() (string, error) { showCmd := exec.Command("git", "rev-parse", "--show-toplevel") diff --git a/git/remote.go b/git/remote.go index bfecfd3b0..6c8608da9 100644 --- a/git/remote.go +++ b/git/remote.go @@ -71,34 +71,22 @@ func parseRemotes(gitRemotes []string) (remotes RemoteSet) { return } -// AddRemote adds a new git remote. The initURL is the remote URL with which the -// automatic fetch is made and finalURL, if non-blank, is set as the remote URL -// after the fetch. -func AddRemote(name, initURL, finalURL string) (*Remote, error) { - addCmd := exec.Command("git", "remote", "add", "-f", name, initURL) +// AddRemote adds a new git remote and auto-fetches objects from it +func AddRemote(name, u string) (*Remote, error) { + addCmd := exec.Command("git", "remote", "add", "-f", name, u) err := run.PrepareCmd(addCmd).Run() if err != nil { return nil, err } - if finalURL == "" { - finalURL = initURL - } else { - setCmd := exec.Command("git", "remote", "set-url", name, finalURL) - err := run.PrepareCmd(setCmd).Run() - if err != nil { - return nil, err - } - } - - finalURLParsed, err := url.Parse(finalURL) + urlParsed, err := url.Parse(u) if err != nil { return nil, err } return &Remote{ Name: name, - FetchURL: finalURLParsed, - PushURL: finalURLParsed, + FetchURL: urlParsed, + PushURL: urlParsed, }, nil } diff --git a/internal/ghrepo/repo.go b/internal/ghrepo/repo.go index f683d9249..a4bfa82d6 100644 --- a/internal/ghrepo/repo.go +++ b/internal/ghrepo/repo.go @@ -8,21 +8,26 @@ import ( const defaultHostname = "github.com" +// Interface describes an object that represents a GitHub repository type Interface interface { RepoName() string RepoOwner() string } +// New instantiates a GitHub repository from owner and name arguments func New(owner, repo string) Interface { return &ghRepo{ owner: owner, name: repo, } } + +// FullName serializes a GitHub repository into an "OWNER/REPO" string func FullName(r Interface) string { return fmt.Sprintf("%s/%s", r.RepoOwner(), r.RepoName()) } +// FromFullName extracts the GitHub repository inforation from an "OWNER/REPO" string func FromFullName(nwo string) Interface { var r ghRepo parts := strings.SplitN(nwo, "/", 2) @@ -32,8 +37,9 @@ func FromFullName(nwo string) Interface { return &r } +// FromURL extracts the GitHub repository information from a URL func FromURL(u *url.URL) (Interface, error) { - if !strings.EqualFold(u.Hostname(), defaultHostname) { + if !strings.EqualFold(u.Hostname(), defaultHostname) && !strings.EqualFold(u.Hostname(), "www."+defaultHostname) { return nil, fmt.Errorf("unsupported hostname: %s", u.Hostname()) } parts := strings.SplitN(strings.TrimPrefix(u.Path, "/"), "/", 3) @@ -43,6 +49,7 @@ func FromURL(u *url.URL) (Interface, error) { return New(parts[0], strings.TrimSuffix(parts[1], ".git")), nil } +// IsSame compares two GitHub repositories func IsSame(a, b Interface) bool { return strings.EqualFold(a.RepoOwner(), b.RepoOwner()) && strings.EqualFold(a.RepoName(), b.RepoName()) diff --git a/internal/ghrepo/repo_test.go b/internal/ghrepo/repo_test.go index 6ff775c84..fef04fb8c 100644 --- a/internal/ghrepo/repo_test.go +++ b/internal/ghrepo/repo_test.go @@ -1,40 +1,66 @@ package ghrepo import ( + "errors" + "fmt" "net/url" "testing" ) func Test_repoFromURL(t *testing.T) { - u, _ := url.Parse("http://github.com/monalisa/octo-cat.git") - repo, err := FromURL(u) - if err != nil { - t.Fatalf("got error %q", err) + tests := []struct { + name string + input string + result string + err error + }{ + { + name: "github.com URL", + input: "https://github.com/monalisa/octo-cat.git", + result: "monalisa/octo-cat", + err: nil, + }, + { + name: "www.github.com URL", + input: "http://www.GITHUB.com/monalisa/octo-cat.git", + result: "monalisa/octo-cat", + err: nil, + }, + { + name: "unsupported hostname", + input: "https://example.com/one/two", + result: "", + err: errors.New("unsupported hostname: example.com"), + }, + { + name: "filesystem path", + input: "/path/to/file", + result: "", + err: errors.New("unsupported hostname: "), + }, } - if repo.RepoOwner() != "monalisa" { - t.Errorf("got owner %q", repo.RepoOwner()) - } - if repo.RepoName() != "octo-cat" { - t.Errorf("got name %q", repo.RepoName()) - } -} -func Test_repoFromURL_invalid(t *testing.T) { - cases := [][]string{ - []string{ - "https://example.com/one/two", - "unsupported hostname: example.com", - }, - []string{ - "/path/to/disk", - "unsupported hostname: ", - }, - } - for _, c := range cases { - u, _ := url.Parse(c[0]) - _, err := FromURL(u) - if err == nil || err.Error() != c[1] { - t.Errorf("got %q", err) - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u, err := url.Parse(tt.input) + if err != nil { + t.Fatalf("got error %q", err) + } + + repo, err := FromURL(u) + if err != nil { + if tt.err == nil { + t.Fatalf("got error %q", err) + } else if tt.err.Error() == err.Error() { + return + } + t.Fatalf("got error %q", err) + } + + got := fmt.Sprintf("%s/%s", repo.RepoOwner(), repo.RepoName()) + if tt.result != got { + t.Errorf("expected %q, got %q", tt.result, got) + } + }) } }