diff --git a/pkg/cmd/issue/edit/edit_test.go b/pkg/cmd/issue/edit/edit_test.go index acc3b62fa..2b0cd87da 100644 --- a/pkg/cmd/issue/edit/edit_test.go +++ b/pkg/cmd/issue/edit/edit_test.go @@ -8,6 +8,7 @@ import ( "path/filepath" "testing" + "github.com/cli/cli/v2/api" "github.com/cli/cli/v2/internal/ghrepo" prShared "github.com/cli/cli/v2/pkg/cmd/pr/shared" "github.com/cli/cli/v2/pkg/cmdutil" @@ -286,6 +287,11 @@ func Test_editRun(t *testing.T) { Value: "GA", Edited: true, }, + Metadata: api.RepoMetadataResult{ + Labels: []api.RepoLabel{ + {Name: "docs", ID: "DOCSID"}, + }, + }, }, FetchOptions: prShared.FetchOptions, }, @@ -293,6 +299,7 @@ func Test_editRun(t *testing.T) { mockIssueGet(t, reg) mockRepoMetadata(t, reg) mockIssueUpdate(t, reg) + mockIssueUpdateLabels(t, reg) }, stdout: "https://github.com/OWNER/REPO/issue/123\n", }, @@ -386,7 +393,8 @@ func mockRepoMetadata(_ *testing.T, reg *httpmock.Registry) { "nodes": [ { "name": "feature", "id": "FEATUREID" }, { "name": "TODO", "id": "TODOID" }, - { "name": "bug", "id": "BUGID" } + { "name": "bug", "id": "BUGID" }, + { "name": "docs", "id": "DOCSID" } ], "pageInfo": { "hasNextPage": false } } } } } @@ -429,9 +437,22 @@ func mockIssueUpdate(t *testing.T, reg *httpmock.Registry) { reg.Register( httpmock.GraphQL(`mutation IssueUpdate\b`), httpmock.GraphQLMutation(` - { "data": { "updateIssue": { "issue": { - "id": "123" - } } } }`, + { "data": { "updateIssue": { "__typename": "" } } }`, + func(inputs map[string]interface{}) {}), + ) +} + +func mockIssueUpdateLabels(t *testing.T, reg *httpmock.Registry) { + reg.Register( + httpmock.GraphQL(`mutation LabelAdd\b`), + httpmock.GraphQLMutation(` + { "data": { "addLabelsToLabelable": { "__typename": "" } } }`, + func(inputs map[string]interface{}) {}), + ) + reg.Register( + httpmock.GraphQL(`mutation LabelRemove\b`), + httpmock.GraphQLMutation(` + { "data": { "removeLabelsFromLabelable": { "__typename": "" } } }`, func(inputs map[string]interface{}) {}), ) } diff --git a/pkg/cmd/pr/edit/edit.go b/pkg/cmd/pr/edit/edit.go index f69a70e79..bad2cfa8f 100644 --- a/pkg/cmd/pr/edit/edit.go +++ b/pkg/cmd/pr/edit/edit.go @@ -13,6 +13,7 @@ import ( "github.com/cli/cli/v2/pkg/iostreams" "github.com/shurcooL/githubv4" "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" ) type EditOptions struct { @@ -214,16 +215,19 @@ func editRun(opts *EditOptions) error { } func updatePullRequest(httpClient *http.Client, repo ghrepo.Interface, id string, editable shared.Editable) error { - if err := shared.UpdateIssue(httpClient, repo, id, true, editable); err != nil { - return err + var wg errgroup.Group + wg.Go(func() error { + return shared.UpdateIssue(httpClient, repo, id, true, editable) + }) + if editable.Reviewers.Edited { + wg.Go(func() error { + return updatePullRequestReviews(httpClient, repo, id, editable) + }) } - return updatePullRequestReviews(httpClient, repo, id, editable) + return wg.Wait() } func updatePullRequestReviews(httpClient *http.Client, repo ghrepo.Interface, id string, editable shared.Editable) error { - if !editable.Reviewers.Edited { - return nil - } userIds, teamIds, err := editable.ReviewerIds() if err != nil { return err diff --git a/pkg/cmd/pr/edit/edit_test.go b/pkg/cmd/pr/edit/edit_test.go index 6dcbdbe20..00ff74382 100644 --- a/pkg/cmd/pr/edit/edit_test.go +++ b/pkg/cmd/pr/edit/edit_test.go @@ -357,6 +357,7 @@ func Test_editRun(t *testing.T) { mockRepoMetadata(t, reg, false) mockPullRequestUpdate(t, reg) mockPullRequestReviewersUpdate(t, reg) + mockPullRequestUpdateLabels(t, reg) }, stdout: "https://github.com/OWNER/REPO/pull/123\n", }, @@ -387,7 +388,7 @@ func Test_editRun(t *testing.T) { Edited: true, }, Labels: shared.EditableSlice{ - Value: []string{"feature", "TODO", "bug"}, + Add: []string{"feature", "TODO", "bug"}, Remove: []string{"docs"}, Edited: true, }, @@ -406,6 +407,7 @@ func Test_editRun(t *testing.T) { httpStubs: func(t *testing.T, reg *httpmock.Registry) { mockRepoMetadata(t, reg, true) mockPullRequestUpdate(t, reg) + mockPullRequestUpdateLabels(t, reg) }, stdout: "https://github.com/OWNER/REPO/pull/123\n", }, @@ -490,7 +492,8 @@ func mockRepoMetadata(_ *testing.T, reg *httpmock.Registry, skipReviewers bool) "nodes": [ { "name": "feature", "id": "FEATUREID" }, { "name": "TODO", "id": "TODOID" }, - { "name": "bug", "id": "BUGID" } + { "name": "bug", "id": "BUGID" }, + { "name": "docs", "id": "DOCSID" } ], "pageInfo": { "hasNextPage": false } } } } } @@ -554,6 +557,21 @@ func mockPullRequestReviewersUpdate(t *testing.T, reg *httpmock.Registry) { httpmock.StringResponse(`{}`)) } +func mockPullRequestUpdateLabels(t *testing.T, reg *httpmock.Registry) { + reg.Register( + httpmock.GraphQL(`mutation LabelAdd\b`), + httpmock.GraphQLMutation(` + { "data": { "addLabelsToLabelable": { "__typename": "" } } }`, + func(inputs map[string]interface{}) {}), + ) + reg.Register( + httpmock.GraphQL(`mutation LabelRemove\b`), + httpmock.GraphQLMutation(` + { "data": { "removeLabelsFromLabelable": { "__typename": "" } } }`, + func(inputs map[string]interface{}) {}), + ) +} + type testFetcher struct{} type testSurveyor struct { skipReviewers bool diff --git a/pkg/cmd/pr/shared/editable.go b/pkg/cmd/pr/shared/editable.go index 8037c5aec..57b042937 100644 --- a/pkg/cmd/pr/shared/editable.go +++ b/pkg/cmd/pr/shared/editable.go @@ -120,21 +120,6 @@ func (e Editable) AssigneeIds(client *api.Client, repo ghrepo.Interface) (*[]str return &a, err } -func (e Editable) LabelIds() (*[]string, error) { - if !e.Labels.Edited { - return nil, nil - } - if len(e.Labels.Add) != 0 || len(e.Labels.Remove) != 0 { - s := set.NewStringSet() - s.AddValues(e.Labels.Default) - s.AddValues(e.Labels.Add) - s.RemoveValues(e.Labels.Remove) - e.Labels.Value = s.ToSlice() - } - l, err := e.Metadata.LabelsToIDs(e.Labels.Value) - return &l, err -} - func (e Editable) ProjectIds() (*[]string, error) { if !e.Projects.Edited { return nil, nil @@ -189,10 +174,22 @@ func EditFieldsSurvey(editable *Editable, editorCommand string) error { } } if editable.Labels.Edited { - editable.Labels.Value, err = multiSelectSurvey("Labels", editable.Labels.Default, editable.Labels.Options) + editable.Labels.Add, err = multiSelectSurvey("Labels", editable.Labels.Default, editable.Labels.Options) if err != nil { return err } + for _, prev := range editable.Labels.Default { + var found bool + for _, selected := range editable.Labels.Add { + if prev == selected { + found = true + break + } + } + if !found { + editable.Labels.Remove = append(editable.Labels.Remove, prev) + } + } } if editable.Projects.Edited { editable.Projects.Value, err = multiSelectSurvey("Projects", editable.Projects.Default, editable.Projects.Options) diff --git a/pkg/cmd/pr/shared/editable_http.go b/pkg/cmd/pr/shared/editable_http.go index 8af23c6c3..f9b19cb07 100644 --- a/pkg/cmd/pr/shared/editable_http.go +++ b/pkg/cmd/pr/shared/editable_http.go @@ -9,19 +9,47 @@ import ( "github.com/cli/cli/v2/internal/ghrepo" graphql "github.com/cli/shurcooL-graphql" "github.com/shurcooL/githubv4" + "golang.org/x/sync/errgroup" ) func UpdateIssue(httpClient *http.Client, repo ghrepo.Interface, id string, isPR bool, options Editable) error { - title := ghString(options.TitleValue()) - body := ghString(options.BodyValue()) + var wg errgroup.Group - apiClient := api.NewClientFromHTTP(httpClient) - assigneeIds, err := options.AssigneeIds(apiClient, repo) - if err != nil { - return err + // Labels are updated through discrete mutations to avoid having to replace the entire list of labels + // and risking race conditions. + if options.Labels.Edited { + if len(options.Labels.Add) > 0 { + wg.Go(func() error { + addedLabelIds, err := options.Metadata.LabelsToIDs(options.Labels.Add) + if err != nil { + return err + } + return addLabels(httpClient, id, repo, addedLabelIds) + }) + } + if len(options.Labels.Remove) > 0 { + wg.Go(func() error { + removeLabelIds, err := options.Metadata.LabelsToIDs(options.Labels.Remove) + if err != nil { + return err + } + return removeLabels(httpClient, id, repo, removeLabelIds) + }) + } } - labelIds, err := options.LabelIds() + if dirtyExcludingLabels(options) { + wg.Go(func() error { + return replaceIssueFields(httpClient, repo, id, isPR, options) + }) + } + + return wg.Wait() +} + +func replaceIssueFields(httpClient *http.Client, repo ghrepo.Interface, id string, isPR bool, options Editable) error { + apiClient := api.NewClientFromHTTP(httpClient) + assigneeIds, err := options.AssigneeIds(apiClient, repo) if err != nil { return err } @@ -39,10 +67,9 @@ func UpdateIssue(httpClient *http.Client, repo ghrepo.Interface, id string, isPR if isPR { params := githubv4.UpdatePullRequestInput{ PullRequestID: id, - Title: title, - Body: body, + Title: ghString(options.TitleValue()), + Body: ghString(options.BodyValue()), AssigneeIDs: ghIds(assigneeIds), - LabelIDs: ghIds(labelIds), ProjectIDs: ghIds(projectIds), MilestoneID: ghId(milestoneId), } @@ -52,23 +79,65 @@ func UpdateIssue(httpClient *http.Client, repo ghrepo.Interface, id string, isPR return updatePullRequest(httpClient, repo, params) } - return updateIssue(httpClient, repo, githubv4.UpdateIssueInput{ + params := githubv4.UpdateIssueInput{ ID: id, - Title: title, - Body: body, + Title: ghString(options.TitleValue()), + Body: ghString(options.BodyValue()), AssigneeIDs: ghIds(assigneeIds), - LabelIDs: ghIds(labelIds), ProjectIDs: ghIds(projectIds), MilestoneID: ghId(milestoneId), - }) + } + return updateIssue(httpClient, repo, params) +} + +func dirtyExcludingLabels(e Editable) bool { + return e.Title.Edited || + e.Body.Edited || + e.Base.Edited || + e.Reviewers.Edited || + e.Assignees.Edited || + e.Projects.Edited || + e.Milestone.Edited +} + +func addLabels(httpClient *http.Client, id string, repo ghrepo.Interface, labels []string) error { + params := githubv4.AddLabelsToLabelableInput{ + LabelableID: id, + LabelIDs: *ghIds(&labels), + } + + var mutation struct { + AddLabelsToLabelable struct { + Typename string `graphql:"__typename"` + } `graphql:"addLabelsToLabelable(input: $input)"` + } + + variables := map[string]interface{}{"input": params} + gql := graphql.NewClient(ghinstance.GraphQLEndpoint(repo.RepoHost()), httpClient) + return gql.MutateNamed(context.Background(), "LabelAdd", &mutation, variables) +} + +func removeLabels(httpClient *http.Client, id string, repo ghrepo.Interface, labels []string) error { + params := githubv4.RemoveLabelsFromLabelableInput{ + LabelableID: id, + LabelIDs: *ghIds(&labels), + } + + var mutation struct { + RemoveLabelsFromLabelable struct { + Typename string `graphql:"__typename"` + } `graphql:"removeLabelsFromLabelable(input: $input)"` + } + + variables := map[string]interface{}{"input": params} + gql := graphql.NewClient(ghinstance.GraphQLEndpoint(repo.RepoHost()), httpClient) + return gql.MutateNamed(context.Background(), "LabelRemove", &mutation, variables) } func updateIssue(httpClient *http.Client, repo ghrepo.Interface, params githubv4.UpdateIssueInput) error { var mutation struct { UpdateIssue struct { - Issue struct { - ID string - } + Typename string `graphql:"__typename"` } `graphql:"updateIssue(input: $input)"` } variables := map[string]interface{}{"input": params} @@ -79,9 +148,7 @@ func updateIssue(httpClient *http.Client, repo ghrepo.Interface, params githubv4 func updatePullRequest(httpClient *http.Client, repo ghrepo.Interface, params githubv4.UpdatePullRequestInput) error { var mutation struct { UpdatePullRequest struct { - PullRequest struct { - ID string - } + Typename string `graphql:"__typename"` } `graphql:"updatePullRequest(input: $input)"` } variables := map[string]interface{}{"input": params}