diff --git a/api/queries_repo.go b/api/queries_repo.go index 048b72b50..79068f0ea 100644 --- a/api/queries_repo.go +++ b/api/queries_repo.go @@ -3,6 +3,7 @@ package api import ( "bytes" "encoding/json" + "errors" "fmt" "sort" "strings" @@ -220,6 +221,44 @@ func ForkRepo(client *Client, repo ghrepo.Interface) (*Repository, error) { }, nil } +// RepoFindFork finds a fork of repo affiliated with the viewer +func RepoFindFork(client *Client, repo ghrepo.Interface) (*Repository, error) { + result := struct { + Repository struct { + Forks struct { + Nodes []Repository + } + } + }{} + + variables := map[string]interface{}{ + "owner": repo.RepoOwner(), + "repo": repo.RepoName(), + } + + if err := client.GraphQL(` + query($owner: String!, $repo: String!) { + repository(owner: $owner, name: $repo) { + forks(first: 1, affiliations: [OWNER, COLLABORATOR]) { + nodes { + id + name + owner { login } + url + } + } + } + } + `, variables, &result); err != nil { + return nil, err + } + + if len(result.Repository.Forks.Nodes) > 0 { + return &result.Repository.Forks.Nodes[0], nil + } + return nil, &NotFoundError{errors.New("no fork found")} +} + // RepoCreateInput represents input parameters for RepoCreate type RepoCreateInput struct { Name string `json:"name"` diff --git a/command/pr_create.go b/command/pr_create.go index ca7800a19..e4a4bd2de 100644 --- a/command/pr_create.go +++ b/command/pr_create.go @@ -193,7 +193,6 @@ func prCreate(cmd *cobra.Command, _ []string) error { } didForkRepo := false - var headRemote *context.Remote if headRepoErr != nil { if baseRepo.IsPrivate { return fmt.Errorf("cannot fork private repository '%s'", ghrepo.FullName(baseRepo)) @@ -203,19 +202,6 @@ func prCreate(cmd *cobra.Command, _ []string) error { return fmt.Errorf("error forking repo: %w", err) } didForkRepo = true - // TODO: support non-HTTPS git remote URLs - baseRepoURL := fmt.Sprintf("https://github.com/%s.git", ghrepo.FullName(baseRepo)) - headRepoURL := fmt.Sprintf("https://github.com/%s.git", ghrepo.FullName(headRepo)) - // TODO: figure out what to name the new git remote - gitRemote, err := git.AddRemote("fork", baseRepoURL, headRepoURL) - if err != nil { - return fmt.Errorf("error adding remote: %w", err) - } - headRemote = &context.Remote{ - Remote: gitRemote, - Owner: headRepo.RepoOwner(), - Repo: headRepo.RepoName(), - } } headBranchLabel := headBranch @@ -223,10 +209,26 @@ func prCreate(cmd *cobra.Command, _ []string) error { headBranchLabel = fmt.Sprintf("%s:%s", headRepo.RepoOwner(), headBranch) } - if headRemote == nil { - headRemote, err = repoContext.RemoteForRepo(headRepo) + headRemote, err := repoContext.RemoteForRepo(headRepo) + // There are two cases when an existing remote for the head repo will be + // missing: + // 1. the head repo was just created by auto-forking; + // 2. an existing fork was discovered by quering the API. + // + // In either case, we want to add the head repo as a new git remote so we + // can push to it. + if err != nil { + // TODO: support non-HTTPS git remote URLs + headRepoURL := fmt.Sprintf("https://github.com/%s.git", ghrepo.FullName(headRepo)) + // TODO: prevent clashes with another remote of a same name + gitRemote, err := git.AddRemote("fork", headRepoURL, "") if err != nil { - return fmt.Errorf("git remote not found for head repository: %w", err) + return fmt.Errorf("error adding remote: %w", err) + } + headRemote = &context.Remote{ + Remote: gitRemote, + Owner: headRepo.RepoOwner(), + Repo: headRepo.RepoName(), } } diff --git a/command/pr_create_test.go b/command/pr_create_test.go index 688dfd9c8..ebffb686a 100644 --- a/command/pr_create_test.go +++ b/command/pr_create_test.go @@ -14,6 +14,10 @@ func TestPRCreate(t *testing.T) { initBlankContext("OWNER/REPO", "feature") http := initFakeHTTP() http.StubRepoResponse("OWNER", "REPO") + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { "forks": { "nodes": [ + ] } } } } + `)) http.StubResponse(200, bytes.NewBufferString(` { "data": { "repository": { "pullRequests": { "nodes" : [ ] } } } } @@ -34,7 +38,7 @@ func TestPRCreate(t *testing.T) { output, err := RunCommand(prCreateCmd, `pr create -t "my title" -b "my body"`) eq(t, err, nil) - bodyBytes, _ := ioutil.ReadAll(http.Requests[2].Body) + bodyBytes, _ := ioutil.ReadAll(http.Requests[3].Body) reqBody := struct { Variables struct { Input struct { @@ -61,6 +65,10 @@ func TestPRCreate_alreadyExists(t *testing.T) { initBlankContext("OWNER/REPO", "feature") http := initFakeHTTP() http.StubRepoResponse("OWNER", "REPO") + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { "forks": { "nodes": [ + ] } } } } + `)) http.StubResponse(200, bytes.NewBufferString(` { "data": { "repository": { "pullRequests": { "nodes": [ { "url": "https://github.com/OWNER/REPO/pull/123", @@ -87,6 +95,10 @@ func TestPRCreate_web(t *testing.T) { initBlankContext("OWNER/REPO", "feature") http := initFakeHTTP() http.StubRepoResponse("OWNER", "REPO") + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { "forks": { "nodes": [ + ] } } } } + `)) cs, cmdTeardown := initCmdStubber() defer cmdTeardown() @@ -113,6 +125,10 @@ func TestPRCreate_ReportsUncommittedChanges(t *testing.T) { http := initFakeHTTP() http.StubRepoResponse("OWNER", "REPO") + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { "forks": { "nodes": [ + ] } } } } + `)) http.StubResponse(200, bytes.NewBufferString(` { "data": { "repository": { "pullRequests": { "nodes" : [ ] } } } } @@ -232,6 +248,10 @@ func TestPRCreate_survey_defaults_multicommit(t *testing.T) { initBlankContext("OWNER/REPO", "cool_bug-fixes") http := initFakeHTTP() http.StubRepoResponse("OWNER", "REPO") + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { "forks": { "nodes": [ + ] } } } } + `)) http.StubResponse(200, bytes.NewBufferString(` { "data": { "repository": { "pullRequests": { "nodes" : [ ] } } } } @@ -273,7 +293,7 @@ func TestPRCreate_survey_defaults_multicommit(t *testing.T) { output, err := RunCommand(prCreateCmd, `pr create`) eq(t, err, nil) - bodyBytes, _ := ioutil.ReadAll(http.Requests[2].Body) + bodyBytes, _ := ioutil.ReadAll(http.Requests[3].Body) reqBody := struct { Variables struct { Input struct { @@ -302,6 +322,10 @@ func TestPRCreate_survey_defaults_monocommit(t *testing.T) { initBlankContext("OWNER/REPO", "feature") http := initFakeHTTP() http.StubRepoResponse("OWNER", "REPO") + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { "forks": { "nodes": [ + ] } } } } + `)) http.StubResponse(200, bytes.NewBufferString(` { "data": { "repository": { "pullRequests": { "nodes" : [ ] } } } } @@ -344,7 +368,7 @@ func TestPRCreate_survey_defaults_monocommit(t *testing.T) { output, err := RunCommand(prCreateCmd, `pr create`) eq(t, err, nil) - bodyBytes, _ := ioutil.ReadAll(http.Requests[2].Body) + bodyBytes, _ := ioutil.ReadAll(http.Requests[3].Body) reqBody := struct { Variables struct { Input struct { @@ -373,6 +397,10 @@ func TestPRCreate_survey_autofill(t *testing.T) { initBlankContext("OWNER/REPO", "feature") http := initFakeHTTP() http.StubRepoResponse("OWNER", "REPO") + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { "forks": { "nodes": [ + ] } } } } + `)) http.StubResponse(200, bytes.NewBufferString(` { "data": { "repository": { "pullRequests": { "nodes" : [ ] } } } } @@ -396,7 +424,7 @@ func TestPRCreate_survey_autofill(t *testing.T) { output, err := RunCommand(prCreateCmd, `pr create -f`) eq(t, err, nil) - bodyBytes, _ := ioutil.ReadAll(http.Requests[2].Body) + bodyBytes, _ := ioutil.ReadAll(http.Requests[3].Body) reqBody := struct { Variables struct { Input struct { @@ -457,6 +485,10 @@ func TestPRCreate_defaults_error_interactive(t *testing.T) { initBlankContext("OWNER/REPO", "feature") http := initFakeHTTP() http.StubRepoResponse("OWNER", "REPO") + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { "forks": { "nodes": [ + ] } } } } + `)) http.StubResponse(200, bytes.NewBufferString(` { "data": { "createPullRequest": { "pullRequest": { "URL": "https://github.com/OWNER/REPO/pull/12" diff --git a/context/context.go b/context/context.go index 6c3bd5aee..27744a08a 100644 --- a/context/context.go +++ b/context/context.go @@ -51,7 +51,10 @@ func ResolveRemotesToRepos(remotes Remotes, client *api.Client, base string) (Re repos = append(repos, baseOverride) } - result := ResolvedRemotes{Remotes: remotes} + result := ResolvedRemotes{ + Remotes: remotes, + apiClient: client, + } if hasBaseOverride { result.BaseOverride = baseOverride } @@ -67,6 +70,7 @@ type ResolvedRemotes struct { BaseOverride ghrepo.Interface Remotes Remotes Network api.RepoNetworkResult + apiClient *api.Client } // BaseRepo is the first found repository in the "upstream", "github", "origin" @@ -95,8 +99,30 @@ func (r ResolvedRemotes) BaseRepo() (*api.Repository, error) { return nil, errors.New("not found") } -// HeadRepo is the first found repository that has push access +// HeadRepo is a fork of base repo (if any), or the first found repository that +// has push access func (r ResolvedRemotes) HeadRepo() (*api.Repository, error) { + baseRepo, err := r.BaseRepo() + if err != nil { + return nil, err + } + + // try to find a pushable fork among existing remotes + for _, repo := range r.Network.Repositories { + if repo != nil && repo.Parent != nil && repo.ViewerCanPush() && ghrepo.IsSame(repo.Parent, baseRepo) { + return repo, nil + } + } + + // a fork might still exist on GitHub, so let's query for it + var notFound *api.NotFoundError + if repo, err := api.RepoFindFork(r.apiClient, baseRepo); err == nil { + return repo, nil + } else if !errors.As(err, ¬Found) { + return nil, err + } + + // fall back to any listed repository that has push access for _, repo := range r.Network.Repositories { if repo != nil && repo.ViewerCanPush() { return repo, nil diff --git a/context/remote_test.go b/context/remote_test.go index 04ccbc56c..dd6a5d59a 100644 --- a/context/remote_test.go +++ b/context/remote_test.go @@ -1,6 +1,7 @@ package context import ( + "bytes" "errors" "net/url" "testing" @@ -61,6 +62,14 @@ func Test_translateRemotes(t *testing.T) { } func Test_resolvedRemotes_triangularSetup(t *testing.T) { + http := &api.FakeHTTP{} + apiClient := api.NewClient(api.ReplaceTripper(http)) + + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { "forks": { "nodes": [ + ] } } } } + `)) + resolved := ResolvedRemotes{ BaseOverride: nil, Remotes: Remotes{ @@ -89,6 +98,7 @@ func Test_resolvedRemotes_triangularSetup(t *testing.T) { }, }, }, + apiClient: apiClient, } baseRepo, err := resolved.BaseRepo() @@ -118,6 +128,52 @@ func Test_resolvedRemotes_triangularSetup(t *testing.T) { } } +func Test_resolvedRemotes_forkLookup(t *testing.T) { + http := &api.FakeHTTP{} + apiClient := api.NewClient(api.ReplaceTripper(http)) + + http.StubResponse(200, bytes.NewBufferString(` + { "data": { "repository": { "forks": { "nodes": [ + { "id": "FORKID", + "url": "https://github.com/FORKOWNER/REPO", + "name": "REPO", + "owner": { "login": "FORKOWNER" } + } + ] } } } } + `)) + + resolved := ResolvedRemotes{ + BaseOverride: nil, + Remotes: Remotes{ + &Remote{ + Remote: &git.Remote{Name: "origin"}, + Owner: "OWNER", + Repo: "REPO", + }, + }, + Network: api.RepoNetworkResult{ + Repositories: []*api.Repository{ + &api.Repository{ + Name: "NEWNAME", + Owner: api.RepositoryOwner{Login: "NEWOWNER"}, + ViewerPermission: "READ", + }, + }, + }, + apiClient: apiClient, + } + + headRepo, err := resolved.HeadRepo() + if err != nil { + t.Fatalf("got %v", err) + } + eq(t, ghrepo.FullName(headRepo), "FORKOWNER/REPO") + _, err = resolved.RemoteForRepo(headRepo) + if err == nil { + t.Fatal("expected to not find a matching remote") + } +} + func Test_resolvedRemotes_clonedFork(t *testing.T) { resolved := ResolvedRemotes{ BaseOverride: nil,