From a47ee660a79cc578e3c860015e2329fcecfe84be Mon Sep 17 00:00:00 2001 From: Sam Coe Date: Fri, 5 Feb 2021 11:46:58 -0800 Subject: [PATCH] Pr edit command --- api/queries_pr.go | 28 ++ pkg/cmd/issue/edit/edit.go | 127 ++++------ pkg/cmd/issue/edit/edit_test.go | 20 +- pkg/cmd/pr/edit/edit.go | 259 +++++++++++++++++++ pkg/cmd/pr/edit/edit_test.go | 435 ++++++++++++++++++++++++++++++++ pkg/cmd/pr/pr.go | 2 + pkg/cmd/pr/shared/editable.go | 254 ++++++++++++++----- 7 files changed, 967 insertions(+), 158 deletions(-) create mode 100644 pkg/cmd/pr/edit/edit.go create mode 100644 pkg/cmd/pr/edit/edit_test.go diff --git a/api/queries_pr.go b/api/queries_pr.go index 53935ccbe..e91216479 100644 --- a/api/queries_pr.go +++ b/api/queries_pr.go @@ -835,6 +835,34 @@ func CreatePullRequest(client *Client, repo *Repository, params map[string]inter return pr, nil } +func UpdatePullRequest(client *Client, repo ghrepo.Interface, params githubv4.UpdatePullRequestInput) error { + var mutation struct { + UpdatePullRequest struct { + PullRequest struct { + ID string + } + } `graphql:"updatePullRequest(input: $input)"` + } + variables := map[string]interface{}{"input": params} + gql := graphQLClient(client.http, repo.RepoHost()) + err := gql.MutateNamed(context.Background(), "PullRequestUpdate", &mutation, variables) + return err +} + +func UpdatePullRequestReviews(client *Client, repo ghrepo.Interface, params githubv4.RequestReviewsInput) error { + var mutation struct { + RequestReviews struct { + PullRequest struct { + ID string + } + } `graphql:"requestReviews(input: $input)"` + } + variables := map[string]interface{}{"input": params} + gql := graphQLClient(client.http, repo.RepoHost()) + err := gql.MutateNamed(context.Background(), "PullRequestUpdateRequestReviews", &mutation, variables) + return err +} + func isBlank(v interface{}) bool { switch vv := v.(type) { case string: diff --git a/pkg/cmd/issue/edit/edit.go b/pkg/cmd/issue/edit/edit.go index 856537e3e..c209c8fc7 100644 --- a/pkg/cmd/issue/edit/edit.go +++ b/pkg/cmd/issue/edit/edit.go @@ -22,14 +22,14 @@ type EditOptions struct { BaseRepo func() (ghrepo.Interface, error) DetermineEditor func() (string, error) - FieldsToEditSurvey func(*prShared.EditableOptions) error - EditableSurvey func(string, *prShared.EditableOptions) error - FetchOptions func(*api.Client, ghrepo.Interface, *prShared.EditableOptions) error + FieldsToEditSurvey func(*prShared.Editable) error + EditFieldsSurvey func(*prShared.Editable, string) error + FetchOptions func(*api.Client, ghrepo.Interface, *prShared.Editable) error SelectorArg string Interactive bool - prShared.EditableOptions + prShared.Editable } func NewCmdEdit(f *cmdutil.Factory, runF func(*EditOptions) error) *cobra.Command { @@ -38,7 +38,7 @@ func NewCmdEdit(f *cmdutil.Factory, runF func(*EditOptions) error) *cobra.Comman HttpClient: f.HttpClient, DetermineEditor: func() (string, error) { return cmdutil.DetermineEditor(f.Config) }, FieldsToEditSurvey: prShared.FieldsToEditSurvey, - EditableSurvey: prShared.EditableSurvey, + EditFieldsSurvey: prShared.EditFieldsSurvey, FetchOptions: prShared.FetchOptions, } @@ -62,25 +62,25 @@ func NewCmdEdit(f *cmdutil.Factory, runF func(*EditOptions) error) *cobra.Comman flags := cmd.Flags() if flags.Changed("title") { - opts.EditableOptions.TitleEdited = true + opts.Editable.TitleEdited = true } if flags.Changed("body") { - opts.EditableOptions.BodyEdited = true + opts.Editable.BodyEdited = true } if flags.Changed("assignee") { - opts.EditableOptions.AssigneesEdited = true + opts.Editable.AssigneesEdited = true } if flags.Changed("label") { - opts.EditableOptions.LabelsEdited = true + opts.Editable.LabelsEdited = true } if flags.Changed("project") { - opts.EditableOptions.ProjectsEdited = true + opts.Editable.ProjectsEdited = true } if flags.Changed("milestone") { - opts.EditableOptions.MilestoneEdited = true + opts.Editable.MilestoneEdited = true } - if !opts.EditableOptions.Dirty() { + if !opts.Editable.Dirty() { opts.Interactive = true } @@ -96,12 +96,12 @@ func NewCmdEdit(f *cmdutil.Factory, runF func(*EditOptions) error) *cobra.Comman }, } - cmd.Flags().StringVarP(&opts.EditableOptions.Title, "title", "t", "", "Revise the issue title.") - cmd.Flags().StringVarP(&opts.EditableOptions.Body, "body", "b", "", "Revise the issue body.") - cmd.Flags().StringSliceVarP(&opts.EditableOptions.Assignees, "assignee", "a", nil, "Set assigned people by their `login`. Use \"@me\" to self-assign.") - cmd.Flags().StringSliceVarP(&opts.EditableOptions.Labels, "label", "l", nil, "Set the issue labels by `name`") - cmd.Flags().StringSliceVarP(&opts.EditableOptions.Projects, "project", "p", nil, "Set the projects the issue belongs to by `name`") - cmd.Flags().StringVarP(&opts.EditableOptions.Milestone, "milestone", "m", "", "Set the milestone the issue belongs to by `name`") + cmd.Flags().StringVarP(&opts.Editable.Title, "title", "t", "", "Revise the issue title.") + cmd.Flags().StringVarP(&opts.Editable.Body, "body", "b", "", "Revise the issue body.") + cmd.Flags().StringSliceVarP(&opts.Editable.Assignees, "assignee", "a", nil, "Set assigned people by their `login`. Use \"@me\" to self-assign.") + cmd.Flags().StringSliceVarP(&opts.Editable.Labels, "label", "l", nil, "Set the issue labels by `name`") + cmd.Flags().StringSliceVarP(&opts.Editable.Projects, "project", "p", nil, "Set the projects the issue belongs to by `name`") + cmd.Flags().StringVarP(&opts.Editable.Milestone, "milestone", "m", "", "Set the milestone the issue belongs to by `name`") return cmd } @@ -118,23 +118,23 @@ func editRun(opts *EditOptions) error { return err } - editOptions := opts.EditableOptions - editOptions.TitleDefault = issue.Title - editOptions.BodyDefault = issue.Body - editOptions.AssigneesDefault = issue.Assignees - editOptions.LabelsDefault = issue.Labels - editOptions.ProjectsDefault = issue.ProjectCards - editOptions.MilestoneDefault = issue.Milestone + editable := opts.Editable + editable.TitleDefault = issue.Title + editable.BodyDefault = issue.Body + editable.AssigneesDefault = issue.Assignees + editable.LabelsDefault = issue.Labels + editable.ProjectsDefault = issue.ProjectCards + editable.MilestoneDefault = issue.Milestone if opts.Interactive { - err = opts.FieldsToEditSurvey(&editOptions) + err = opts.FieldsToEditSurvey(&editable) if err != nil { return err } } opts.IO.StartProgressIndicator() - err = opts.FetchOptions(apiClient, repo, &editOptions) + err = opts.FetchOptions(apiClient, repo, &editable) opts.IO.StopProgressIndicator() if err != nil { return err @@ -145,14 +145,14 @@ func editRun(opts *EditOptions) error { if err != nil { return err } - err = opts.EditableSurvey(editorCommand, &editOptions) + err = opts.EditFieldsSurvey(&editable, editorCommand) if err != nil { return err } } opts.IO.StartProgressIndicator() - err = updateIssue(apiClient, repo, issue.ID, editOptions) + err = updateIssue(apiClient, repo, issue.ID, editable) opts.IO.StopProgressIndicator() if err != nil { return err @@ -163,61 +163,28 @@ func editRun(opts *EditOptions) error { return nil } -func updateIssue(client *api.Client, repo ghrepo.Interface, id string, options prShared.EditableOptions) error { - params := githubv4.UpdateIssueInput{ID: id} - if options.TitleEdited { - title := githubv4.String(options.Title) - params.Title = &title +func updateIssue(client *api.Client, repo ghrepo.Interface, id string, options prShared.Editable) error { + var err error + params := githubv4.UpdateIssueInput{ + ID: id, + Title: options.TitleParam(), + Body: options.BodyParam(), } - if options.BodyEdited { - body := githubv4.String(options.Body) - params.Body = &body + params.AssigneeIDs, err = options.AssigneesParam(client, repo) + if err != nil { + return err } - if options.AssigneesEdited { - meReplacer := prShared.NewMeReplacer(client, repo.RepoHost()) - assignees, err := meReplacer.ReplaceSlice(options.Assignees) - if err != nil { - return err - } - ids, err := options.Metadata.MembersToIDs(assignees) - if err != nil { - return err - } - assigneeIDs := make([]githubv4.ID, len(ids)) - for i, v := range ids { - assigneeIDs[i] = v - } - params.AssigneeIDs = &assigneeIDs + params.LabelIDs, err = options.LabelsParam() + if err != nil { + return err } - if options.LabelsEdited { - ids, err := options.Metadata.LabelsToIDs(options.Labels) - if err != nil { - return err - } - labelIDs := make([]githubv4.ID, len(ids)) - for i, v := range ids { - labelIDs[i] = v - } - params.LabelIDs = &labelIDs + params.ProjectIDs, err = options.ProjectsParam() + if err != nil { + return err } - if options.ProjectsEdited { - ids, err := options.Metadata.ProjectsToIDs(options.Projects) - if err != nil { - return err - } - projectIDs := make([]githubv4.ID, len(ids)) - for i, v := range ids { - projectIDs[i] = v - } - params.ProjectIDs = &projectIDs - } - if options.MilestoneEdited { - id, err := options.Metadata.MilestoneToID(options.Milestone) - if err != nil { - return err - } - milestoneID := githubv4.ID(id) - params.MilestoneID = &milestoneID + params.MilestoneID, err = options.MilestoneParam() + if err != nil { + return err } return api.IssueUpdate(client, repo, params) } diff --git a/pkg/cmd/issue/edit/edit_test.go b/pkg/cmd/issue/edit/edit_test.go index a5d090c51..b132db26f 100644 --- a/pkg/cmd/issue/edit/edit_test.go +++ b/pkg/cmd/issue/edit/edit_test.go @@ -41,7 +41,7 @@ func TestNewCmdEdit(t *testing.T) { input: "23 --title test", output: EditOptions{ SelectorArg: "23", - EditableOptions: prShared.EditableOptions{ + Editable: prShared.Editable{ Title: "test", TitleEdited: true, }, @@ -53,7 +53,7 @@ func TestNewCmdEdit(t *testing.T) { input: "23 --body test", output: EditOptions{ SelectorArg: "23", - EditableOptions: prShared.EditableOptions{ + Editable: prShared.Editable{ Body: "test", BodyEdited: true, }, @@ -65,7 +65,7 @@ func TestNewCmdEdit(t *testing.T) { input: "23 --assignee monalisa,hubot", output: EditOptions{ SelectorArg: "23", - EditableOptions: prShared.EditableOptions{ + Editable: prShared.Editable{ Assignees: []string{"monalisa", "hubot"}, AssigneesEdited: true, }, @@ -77,7 +77,7 @@ func TestNewCmdEdit(t *testing.T) { input: "23 --label feature,TODO,bug", output: EditOptions{ SelectorArg: "23", - EditableOptions: prShared.EditableOptions{ + Editable: prShared.Editable{ Labels: []string{"feature", "TODO", "bug"}, LabelsEdited: true, }, @@ -89,7 +89,7 @@ func TestNewCmdEdit(t *testing.T) { input: "23 --project Cleanup,Roadmap", output: EditOptions{ SelectorArg: "23", - EditableOptions: prShared.EditableOptions{ + Editable: prShared.Editable{ Projects: []string{"Cleanup", "Roadmap"}, ProjectsEdited: true, }, @@ -101,7 +101,7 @@ func TestNewCmdEdit(t *testing.T) { input: "23 --milestone GA", output: EditOptions{ SelectorArg: "23", - EditableOptions: prShared.EditableOptions{ + Editable: prShared.Editable{ Milestone: "GA", MilestoneEdited: true, }, @@ -144,7 +144,7 @@ func TestNewCmdEdit(t *testing.T) { assert.NoError(t, err) assert.Equal(t, tt.output.SelectorArg, gotOpts.SelectorArg) assert.Equal(t, tt.output.Interactive, gotOpts.Interactive) - assert.Equal(t, tt.output.EditableOptions, gotOpts.EditableOptions) + assert.Equal(t, tt.output.Editable, gotOpts.Editable) }) } } @@ -162,7 +162,7 @@ func Test_editRun(t *testing.T) { input: &EditOptions{ SelectorArg: "123", Interactive: false, - EditableOptions: prShared.EditableOptions{ + Editable: prShared.Editable{ Title: "new title", TitleEdited: true, Body: "new body", @@ -190,7 +190,7 @@ func Test_editRun(t *testing.T) { input: &EditOptions{ SelectorArg: "123", Interactive: true, - FieldsToEditSurvey: func(eo *prShared.EditableOptions) error { + FieldsToEditSurvey: func(eo *prShared.Editable) error { eo.TitleEdited = true eo.BodyEdited = true eo.AssigneesEdited = true @@ -199,7 +199,7 @@ func Test_editRun(t *testing.T) { eo.MilestoneEdited = true return nil }, - EditableSurvey: func(_ string, eo *prShared.EditableOptions) error { + EditFieldsSurvey: func(eo *prShared.Editable, _ string) error { eo.Title = "new title" eo.Body = "new body" eo.Assignees = []string{"monalisa", "hubot"} diff --git a/pkg/cmd/pr/edit/edit.go b/pkg/cmd/pr/edit/edit.go new file mode 100644 index 000000000..0bd5ce826 --- /dev/null +++ b/pkg/cmd/pr/edit/edit.go @@ -0,0 +1,259 @@ +package edit + +import ( + "errors" + "fmt" + "net/http" + + "github.com/MakeNowJust/heredoc" + "github.com/cli/cli/api" + "github.com/cli/cli/context" + "github.com/cli/cli/internal/config" + "github.com/cli/cli/internal/ghrepo" + shared "github.com/cli/cli/pkg/cmd/pr/shared" + "github.com/cli/cli/pkg/cmdutil" + "github.com/cli/cli/pkg/iostreams" + "github.com/shurcooL/githubv4" + "github.com/spf13/cobra" +) + +type EditOptions struct { + HttpClient func() (*http.Client, error) + IO *iostreams.IOStreams + BaseRepo func() (ghrepo.Interface, error) + Remotes func() (context.Remotes, error) + Branch func() (string, error) + + Surveyor Surveyor + Fetcher EditableOptionsFetcher + EditorRetriever EditorRetriever + + SelectorArg string + Interactive bool + + shared.Editable +} + +func NewCmdEdit(f *cmdutil.Factory, runF func(*EditOptions) error) *cobra.Command { + opts := &EditOptions{ + IO: f.IOStreams, + HttpClient: f.HttpClient, + Remotes: f.Remotes, + Branch: f.Branch, + Surveyor: surveyor{}, + Fetcher: fetcher{}, + EditorRetriever: editorRetriever{config: f.Config}, + } + + cmd := &cobra.Command{ + Use: "edit { | }", + Short: "Edit a pull request", + Example: heredoc.Doc(` + $ gh pr edit 23 --title "I found a bug" --body "Nothing works" + $ gh pr edit 23 --label "bug,help wanted" + $ gh pr edit 23 --label bug --label "help wanted" + $ gh pr edit 23 --reviewer monalisa,hubot --reviewer myorg/team-name + $ gh pr edit 23 --assignee monalisa,hubot + $ gh pr edit 23 --assignee @me + $ gh pr edit 23 --project "Roadmap" + `), + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + // support `-R, --repo` override + opts.BaseRepo = f.BaseRepo + + opts.SelectorArg = args[0] + + flags := cmd.Flags() + if flags.Changed("title") { + opts.Editable.TitleEdited = true + } + if flags.Changed("body") { + opts.Editable.BodyEdited = true + } + if flags.Changed("reviewer") { + opts.Editable.ReviewersEdited = true + } + if flags.Changed("assignee") { + opts.Editable.AssigneesEdited = true + } + if flags.Changed("label") { + opts.Editable.LabelsEdited = true + } + if flags.Changed("project") { + opts.Editable.ProjectsEdited = true + } + if flags.Changed("milestone") { + opts.Editable.MilestoneEdited = true + } + + if !opts.Editable.Dirty() { + opts.Interactive = true + } + + if opts.Interactive && !opts.IO.CanPrompt() { + return &cmdutil.FlagError{Err: errors.New("--tile, --body, --reviewer, --assignee, --label, --project, or --milestone required when not running interactively")} + } + + if runF != nil { + return runF(opts) + } + + return editRun(opts) + }, + } + + cmd.Flags().StringVarP(&opts.Editable.Title, "title", "t", "", "Revise the pr title.") + cmd.Flags().StringVarP(&opts.Editable.Body, "body", "b", "", "Revise the pr body.") + cmd.Flags().StringSliceVarP(&opts.Editable.Reviewers, "reviewer", "r", nil, "Request reviews from people or teams by their `handle`") + cmd.Flags().StringSliceVarP(&opts.Editable.Assignees, "assignee", "a", nil, "Set assigned people by their `login`. Use \"@me\" to self-assign.") + cmd.Flags().StringSliceVarP(&opts.Editable.Labels, "label", "l", nil, "Set the pr labels by `name`") + cmd.Flags().StringSliceVarP(&opts.Editable.Projects, "project", "p", nil, "Set the projects the pr belongs to by `name`") + cmd.Flags().StringVarP(&opts.Editable.Milestone, "milestone", "m", "", "Set the milestone the pr belongs to by `name`") + + return cmd +} + +func editRun(opts *EditOptions) error { + httpClient, err := opts.HttpClient() + if err != nil { + return err + } + apiClient := api.NewClientFromHTTP(httpClient) + + pr, repo, err := shared.PRFromArgs(apiClient, opts.BaseRepo, opts.Branch, opts.Remotes, opts.SelectorArg) + if err != nil { + return err + } + + editable := opts.Editable + editable.ReviewersAllowed = true + editable.TitleDefault = pr.Title + editable.BodyDefault = pr.Body + editable.ReviewersDefault = pr.ReviewRequests + editable.AssigneesDefault = pr.Assignees + editable.LabelsDefault = pr.Labels + editable.ProjectsDefault = pr.ProjectCards + editable.MilestoneDefault = pr.Milestone + + if opts.Interactive { + err = opts.Surveyor.FieldsToEdit(&editable) + if err != nil { + return err + } + } + + opts.IO.StartProgressIndicator() + err = opts.Fetcher.EditableOptionsFetch(apiClient, repo, &editable) + opts.IO.StopProgressIndicator() + if err != nil { + return err + } + + if opts.Interactive { + editorCommand, err := opts.EditorRetriever.Retrieve() + if err != nil { + return err + } + err = opts.Surveyor.EditFields(&editable, editorCommand) + if err != nil { + return err + } + } + + opts.IO.StartProgressIndicator() + err = updatePullRequest(apiClient, repo, pr.ID, editable) + opts.IO.StopProgressIndicator() + if err != nil { + return err + } + + fmt.Fprintln(opts.IO.Out, pr.URL) + + return nil +} + +func updatePullRequest(client *api.Client, repo ghrepo.Interface, id string, editable shared.Editable) error { + var err error + params := githubv4.UpdatePullRequestInput{ + PullRequestID: id, + Title: editable.TitleParam(), + Body: editable.BodyParam(), + } + params.AssigneeIDs, err = editable.AssigneesParam(client, repo) + if err != nil { + return err + } + params.LabelIDs, err = editable.LabelsParam() + if err != nil { + return err + } + params.ProjectIDs, err = editable.ProjectsParam() + if err != nil { + return err + } + params.MilestoneID, err = editable.MilestoneParam() + if err != nil { + return err + } + err = api.UpdatePullRequest(client, repo, params) + if err != nil { + return err + } + return updatePullRequestReviews(client, repo, id, editable) +} + +func updatePullRequestReviews(client *api.Client, repo ghrepo.Interface, id string, editable shared.Editable) error { + if !editable.ReviewersEdited { + return nil + } + userIds, teamIds, err := editable.ReviewersParams() + if err != nil { + return err + } + union := githubv4.Boolean(false) + reviewsRequestParams := githubv4.RequestReviewsInput{ + PullRequestID: id, + Union: &union, + UserIDs: userIds, + TeamIDs: teamIds, + } + return api.UpdatePullRequestReviews(client, repo, reviewsRequestParams) +} + +type Surveyor interface { + FieldsToEdit(*shared.Editable) error + EditFields(*shared.Editable, string) error +} + +type surveyor struct{} + +func (s surveyor) FieldsToEdit(editable *shared.Editable) error { + return shared.FieldsToEditSurvey(editable) +} + +func (s surveyor) EditFields(editable *shared.Editable, editorCmd string) error { + return shared.EditFieldsSurvey(editable, editorCmd) +} + +type EditableOptionsFetcher interface { + EditableOptionsFetch(*api.Client, ghrepo.Interface, *shared.Editable) error +} + +type fetcher struct{} + +func (f fetcher) EditableOptionsFetch(client *api.Client, repo ghrepo.Interface, opts *shared.Editable) error { + return shared.FetchOptions(client, repo, opts) +} + +type EditorRetriever interface { + Retrieve() (string, error) +} + +type editorRetriever struct { + config func() (config.Config, error) +} + +func (e editorRetriever) Retrieve() (string, error) { + return cmdutil.DetermineEditor(e.config) +} diff --git a/pkg/cmd/pr/edit/edit_test.go b/pkg/cmd/pr/edit/edit_test.go new file mode 100644 index 000000000..1351ff29a --- /dev/null +++ b/pkg/cmd/pr/edit/edit_test.go @@ -0,0 +1,435 @@ +package edit + +import ( + "bytes" + "net/http" + "testing" + + "github.com/cli/cli/api" + "github.com/cli/cli/internal/ghrepo" + shared "github.com/cli/cli/pkg/cmd/pr/shared" + "github.com/cli/cli/pkg/cmdutil" + "github.com/cli/cli/pkg/httpmock" + "github.com/cli/cli/pkg/iostreams" + "github.com/google/shlex" + "github.com/stretchr/testify/assert" +) + +func TestNewCmdEdit(t *testing.T) { + tests := []struct { + name string + input string + output EditOptions + wantsErr bool + }{ + { + name: "no argument", + input: "", + output: EditOptions{}, + wantsErr: true, + }, + { + name: "issue number argument", + input: "23", + output: EditOptions{ + SelectorArg: "23", + Interactive: true, + }, + wantsErr: false, + }, + { + name: "title flag", + input: "23 --title test", + output: EditOptions{ + SelectorArg: "23", + Editable: shared.Editable{ + Title: "test", + TitleEdited: true, + }, + }, + wantsErr: false, + }, + { + name: "body flag", + input: "23 --body test", + output: EditOptions{ + SelectorArg: "23", + Editable: shared.Editable{ + Body: "test", + BodyEdited: true, + }, + }, + wantsErr: false, + }, + { + name: "reviewer flag", + input: "23 --reviewer owner/team,monalisa", + output: EditOptions{ + SelectorArg: "23", + Editable: shared.Editable{ + Reviewers: []string{"owner/team", "monalisa"}, + ReviewersEdited: true, + }, + }, + wantsErr: false, + }, + { + name: "assignee flag", + input: "23 --assignee monalisa,hubot", + output: EditOptions{ + SelectorArg: "23", + Editable: shared.Editable{ + Assignees: []string{"monalisa", "hubot"}, + AssigneesEdited: true, + }, + }, + wantsErr: false, + }, + { + name: "label flag", + input: "23 --label feature,TODO,bug", + output: EditOptions{ + SelectorArg: "23", + Editable: shared.Editable{ + Labels: []string{"feature", "TODO", "bug"}, + LabelsEdited: true, + }, + }, + wantsErr: false, + }, + { + name: "project flag", + input: "23 --project Cleanup,Roadmap", + output: EditOptions{ + SelectorArg: "23", + Editable: shared.Editable{ + Projects: []string{"Cleanup", "Roadmap"}, + ProjectsEdited: true, + }, + }, + wantsErr: false, + }, + { + name: "milestone flag", + input: "23 --milestone GA", + output: EditOptions{ + SelectorArg: "23", + Editable: shared.Editable{ + Milestone: "GA", + MilestoneEdited: true, + }, + }, + wantsErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + io, _, _, _ := iostreams.Test() + io.SetStdoutTTY(true) + io.SetStdinTTY(true) + io.SetStderrTTY(true) + + f := &cmdutil.Factory{ + IOStreams: io, + } + + argv, err := shlex.Split(tt.input) + assert.NoError(t, err) + + var gotOpts *EditOptions + cmd := NewCmdEdit(f, func(opts *EditOptions) error { + gotOpts = opts + return nil + }) + cmd.Flags().BoolP("help", "x", false, "") + + cmd.SetArgs(argv) + cmd.SetIn(&bytes.Buffer{}) + cmd.SetOut(&bytes.Buffer{}) + cmd.SetErr(&bytes.Buffer{}) + + _, err = cmd.ExecuteC() + if tt.wantsErr { + assert.Error(t, err) + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.output.SelectorArg, gotOpts.SelectorArg) + assert.Equal(t, tt.output.Interactive, gotOpts.Interactive) + assert.Equal(t, tt.output.Editable, gotOpts.Editable) + }) + } +} + +func Test_editRun(t *testing.T) { + tests := []struct { + name string + input *EditOptions + httpStubs func(*testing.T, *httpmock.Registry) + stdout string + stderr string + }{ + { + name: "non-interactive", + input: &EditOptions{ + SelectorArg: "123", + Interactive: false, + Editable: shared.Editable{ + Title: "new title", + TitleEdited: true, + Body: "new body", + BodyEdited: true, + Reviewers: []string{"OWNER/core", "OWNER/external", "monalisa", "hubot"}, + ReviewersEdited: true, + Assignees: []string{"monalisa", "hubot"}, + AssigneesEdited: true, + Labels: []string{"feature", "TODO", "bug"}, + LabelsEdited: true, + Projects: []string{"Cleanup", "Roadmap"}, + ProjectsEdited: true, + Milestone: "GA", + MilestoneEdited: true, + }, + Fetcher: testFetcher{}, + }, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + mockPullRequestGet(t, reg) + mockRepoMetadata(t, reg, false) + mockPullRequestUpdate(t, reg) + mockPullRequestReviewersUpdate(t, reg) + }, + stdout: "https://github.com/OWNER/REPO/pull/123\n", + }, + { + name: "non-interactive skip reviewers", + input: &EditOptions{ + SelectorArg: "123", + Interactive: false, + Editable: shared.Editable{ + Title: "new title", + TitleEdited: true, + Body: "new body", + BodyEdited: true, + Assignees: []string{"monalisa", "hubot"}, + AssigneesEdited: true, + Labels: []string{"feature", "TODO", "bug"}, + LabelsEdited: true, + Projects: []string{"Cleanup", "Roadmap"}, + ProjectsEdited: true, + Milestone: "GA", + MilestoneEdited: true, + }, + Fetcher: testFetcher{}, + }, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + mockPullRequestGet(t, reg) + mockRepoMetadata(t, reg, true) + mockPullRequestUpdate(t, reg) + }, + stdout: "https://github.com/OWNER/REPO/pull/123\n", + }, + { + name: "interactive", + input: &EditOptions{ + SelectorArg: "123", + Interactive: true, + Surveyor: testSurveyor{}, + Fetcher: testFetcher{}, + EditorRetriever: testEditorRetriever{}, + }, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + mockPullRequestGet(t, reg) + mockRepoMetadata(t, reg, false) + mockPullRequestUpdate(t, reg) + mockPullRequestReviewersUpdate(t, reg) + }, + stdout: "https://github.com/OWNER/REPO/pull/123\n", + }, + { + name: "interactive skip reviewers", + input: &EditOptions{ + SelectorArg: "123", + Interactive: true, + Surveyor: testSurveyor{skipReviewers: true}, + Fetcher: testFetcher{}, + EditorRetriever: testEditorRetriever{}, + }, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + mockPullRequestGet(t, reg) + mockRepoMetadata(t, reg, true) + mockPullRequestUpdate(t, reg) + }, + stdout: "https://github.com/OWNER/REPO/pull/123\n", + }, + } + for _, tt := range tests { + io, _, stdout, stderr := iostreams.Test() + io.SetStdoutTTY(true) + io.SetStdinTTY(true) + io.SetStderrTTY(true) + + reg := &httpmock.Registry{} + defer reg.Verify(t) + tt.httpStubs(t, reg) + + httpClient := func() (*http.Client, error) { return &http.Client{Transport: reg}, nil } + baseRepo := func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil } + + tt.input.IO = io + tt.input.HttpClient = httpClient + tt.input.BaseRepo = baseRepo + + t.Run(tt.name, func(t *testing.T) { + err := editRun(tt.input) + assert.NoError(t, err) + assert.Equal(t, tt.stdout, stdout.String()) + assert.Equal(t, tt.stderr, stderr.String()) + }) + } +} + +func mockPullRequestGet(_ *testing.T, reg *httpmock.Registry) { + reg.Register( + httpmock.GraphQL(`query PullRequestByNumber\b`), + httpmock.StringResponse(` + { "data": { "repository": { "pullRequest": { + "id": "456", + "number": 123, + "url": "https://github.com/OWNER/REPO/pull/123" + } } } }`), + ) +} + +func mockRepoMetadata(_ *testing.T, reg *httpmock.Registry, skipReviewers bool) { + reg.Register( + httpmock.GraphQL(`query RepositoryAssignableUsers\b`), + httpmock.StringResponse(` + { "data": { "repository": { "assignableUsers": { + "nodes": [ + { "login": "hubot", "id": "HUBOTID" }, + { "login": "MonaLisa", "id": "MONAID" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + reg.Register( + httpmock.GraphQL(`query RepositoryLabelList\b`), + httpmock.StringResponse(` + { "data": { "repository": { "labels": { + "nodes": [ + { "name": "feature", "id": "FEATUREID" }, + { "name": "TODO", "id": "TODOID" }, + { "name": "bug", "id": "BUGID" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + reg.Register( + httpmock.GraphQL(`query RepositoryMilestoneList\b`), + httpmock.StringResponse(` + { "data": { "repository": { "milestones": { + "nodes": [ + { "title": "GA", "id": "GAID" }, + { "title": "Big One.oh", "id": "BIGONEID" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + reg.Register( + httpmock.GraphQL(`query RepositoryProjectList\b`), + httpmock.StringResponse(` + { "data": { "repository": { "projects": { + "nodes": [ + { "name": "Cleanup", "id": "CLEANUPID" }, + { "name": "Roadmap", "id": "ROADMAPID" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + reg.Register( + httpmock.GraphQL(`query OrganizationProjectList\b`), + httpmock.StringResponse(` + { "data": { "organization": { "projects": { + "nodes": [ + { "name": "Triage", "id": "TRIAGEID" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + if !skipReviewers { + reg.Register( + httpmock.GraphQL(`query OrganizationTeamList\b`), + httpmock.StringResponse(` + { "data": { "organization": { "teams": { + "nodes": [ + { "slug": "external", "id": "EXTERNALID" }, + { "slug": "core", "id": "COREID" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + } +} + +func mockPullRequestUpdate(t *testing.T, reg *httpmock.Registry) { + reg.Register( + httpmock.GraphQL(`mutation PullRequestUpdate\b`), + httpmock.GraphQLMutation(` + { "data": { "updatePullRequest": { "pullRequest": { + "id": "456" + } } } }`, + func(inputs map[string]interface{}) {}), + ) +} + +func mockPullRequestReviewersUpdate(t *testing.T, reg *httpmock.Registry) { + reg.Register( + httpmock.GraphQL(`mutation PullRequestUpdateRequestReviews\b`), + httpmock.GraphQLMutation(` + { "data": { "requestReviews": { "pullRequest": { + "id": "456" + } } } }`, + func(inputs map[string]interface{}) {}), + ) +} + +type testFetcher struct{} +type testSurveyor struct { + skipReviewers bool +} +type testEditorRetriever struct{} + +func (f testFetcher) EditableOptionsFetch(client *api.Client, repo ghrepo.Interface, opts *shared.Editable) error { + return shared.FetchOptions(client, repo, opts) +} + +func (s testSurveyor) FieldsToEdit(e *shared.Editable) error { + e.TitleEdited = true + e.BodyEdited = true + if !s.skipReviewers { + e.ReviewersEdited = true + } + e.AssigneesEdited = true + e.LabelsEdited = true + e.ProjectsEdited = true + e.MilestoneEdited = true + return nil +} + +func (s testSurveyor) EditFields(e *shared.Editable, _ string) error { + e.Title = "new title" + e.Body = "new body" + if !s.skipReviewers { + e.Reviewers = []string{"monalisa", "hubot", "OWNER/core", "OWNER/external"} + } + e.Assignees = []string{"monalisa", "hubot"} + e.Labels = []string{"feature", "TODO", "bug"} + e.Projects = []string{"Cleanup", "Roadmap"} + e.Milestone = "GA" + return nil +} + +func (t testEditorRetriever) Retrieve() (string, error) { + return "vim", nil +} diff --git a/pkg/cmd/pr/pr.go b/pkg/cmd/pr/pr.go index f1981fabe..3be067636 100644 --- a/pkg/cmd/pr/pr.go +++ b/pkg/cmd/pr/pr.go @@ -8,6 +8,7 @@ import ( cmdComment "github.com/cli/cli/pkg/cmd/pr/comment" cmdCreate "github.com/cli/cli/pkg/cmd/pr/create" cmdDiff "github.com/cli/cli/pkg/cmd/pr/diff" + cmdEdit "github.com/cli/cli/pkg/cmd/pr/edit" cmdList "github.com/cli/cli/pkg/cmd/pr/list" cmdMerge "github.com/cli/cli/pkg/cmd/pr/merge" cmdReady "github.com/cli/cli/pkg/cmd/pr/ready" @@ -55,6 +56,7 @@ func NewCmdPR(f *cmdutil.Factory) *cobra.Command { cmd.AddCommand(cmdView.NewCmdView(f, nil)) cmd.AddCommand(cmdChecks.NewCmdChecks(f, nil)) cmd.AddCommand(cmdComment.NewCmdComment(f, nil)) + cmd.AddCommand(cmdEdit.NewCmdEdit(f, nil)) return cmd } diff --git a/pkg/cmd/pr/shared/editable.go b/pkg/cmd/pr/shared/editable.go index 36bd0387c..885d1459d 100644 --- a/pkg/cmd/pr/shared/editable.go +++ b/pkg/cmd/pr/shared/editable.go @@ -2,14 +2,16 @@ package shared import ( "fmt" + "strings" "github.com/AlecAivazis/survey/v2" "github.com/cli/cli/api" "github.com/cli/cli/internal/ghrepo" "github.com/cli/cli/pkg/surveyext" + "github.com/shurcooL/githubv4" ) -type EditableOptions struct { +type Editable struct { Title string TitleDefault string TitleEdited bool @@ -47,7 +49,7 @@ type EditableOptions struct { Metadata api.RepoMetadataResult } -func (e EditableOptions) Dirty() bool { +func (e Editable) Dirty() bool { return e.TitleEdited || e.BodyEdited || e.ReviewersEdited || @@ -57,48 +59,125 @@ func (e EditableOptions) Dirty() bool { e.MilestoneEdited } -func EditableSurvey(editorCommand string, options *EditableOptions) error { - if options.TitleEdited { - title, err := titleSurvey(options.TitleDefault) - if err != nil { - return err - } - options.Title = title +func (e Editable) TitleParam() *githubv4.String { + if !e.TitleEdited { + return nil } - if options.BodyEdited { - body, err := bodySurvey(options.BodyDefault, editorCommand) - if err != nil { - return err - } - options.Body = body + s := githubv4.String(e.Title) + return &s +} + +func (e Editable) BodyParam() *githubv4.String { + if !e.BodyEdited { + return nil } - if options.AssigneesEdited { - assignees, err := assigneesSurvey(options.AssigneesDefault, options.AssigneesOptions) - if err != nil { - return err - } - options.Assignees = assignees + s := githubv4.String(e.Body) + return &s +} + +func (e Editable) ReviewersParams() (*[]githubv4.ID, *[]githubv4.ID, error) { + if !e.ReviewersEdited { + return nil, nil, nil } - if options.LabelsEdited { - labels, err := labelsSurvey(options.LabelsDefault, options.LabelsOptions) - if err != nil { - return err + var userReviewers []string + var teamReviewers []string + for _, r := range e.Reviewers { + if strings.ContainsRune(r, '/') { + teamReviewers = append(teamReviewers, r) + } else { + userReviewers = append(userReviewers, r) } - options.Labels = labels } - if options.ProjectsEdited { - projects, err := projectsSurvey(options.ProjectsDefault, options.ProjectsOptions) - if err != nil { - return err - } - options.Projects = projects + userIds, err := toParams(userReviewers, e.Metadata.MembersToIDs) + if err != nil { + return nil, nil, err } - if options.MilestoneEdited { - milestone, err := milestoneSurvey(options.MilestoneDefault, options.MilestoneOptions) + teamIds, err := toParams(teamReviewers, e.Metadata.TeamsToIDs) + if err != nil { + return nil, nil, err + } + return userIds, teamIds, nil +} + +func (e Editable) AssigneesParam(client *api.Client, repo ghrepo.Interface) (*[]githubv4.ID, error) { + if !e.AssigneesEdited { + return nil, nil + } + meReplacer := NewMeReplacer(client, repo.RepoHost()) + assignees, err := meReplacer.ReplaceSlice(e.Assignees) + if err != nil { + return nil, err + } + return toParams(assignees, e.Metadata.MembersToIDs) +} + +func (e Editable) LabelsParam() (*[]githubv4.ID, error) { + if !e.LabelsEdited { + return nil, nil + } + return toParams(e.Labels, e.Metadata.LabelsToIDs) +} + +func (e Editable) ProjectsParam() (*[]githubv4.ID, error) { + if !e.ProjectsEdited { + return nil, nil + } + return toParams(e.Projects, e.Metadata.ProjectsToIDs) +} + +func (e Editable) MilestoneParam() (*githubv4.ID, error) { + if !e.MilestoneEdited { + return nil, nil + } + if e.Milestone == noMilestone || e.Milestone == "" { + return githubv4.NewID(nil), nil + } + return toParam(e.Milestone, e.Metadata.MilestoneToID) +} + +func EditFieldsSurvey(editable *Editable, editorCommand string) error { + var err error + if editable.TitleEdited { + editable.Title, err = titleSurvey(editable.TitleDefault) + if err != nil { + return err + } + } + if editable.BodyEdited { + editable.Body, err = bodySurvey(editable.BodyDefault, editorCommand) + if err != nil { + return err + } + } + if editable.ReviewersEdited { + editable.Reviewers, err = reviewersSurvey(editable.ReviewersDefault, editable.ReviewersOptions) + if err != nil { + return err + } + } + if editable.AssigneesEdited { + editable.Assignees, err = assigneesSurvey(editable.AssigneesDefault, editable.AssigneesOptions) + if err != nil { + return err + } + } + if editable.LabelsEdited { + editable.Labels, err = labelsSurvey(editable.LabelsDefault, editable.LabelsOptions) + if err != nil { + return err + } + } + if editable.ProjectsEdited { + editable.Projects, err = projectsSurvey(editable.ProjectsDefault, editable.ProjectsOptions) + if err != nil { + return err + } + } + if editable.MilestoneEdited { + editable.Milestone, err = milestoneSurvey(editable.MilestoneDefault, editable.MilestoneOptions) if err != nil { return err } - options.Milestone = milestone } confirm, err := confirmSurvey() if err != nil { @@ -111,7 +190,7 @@ func EditableSurvey(editorCommand string, options *EditableOptions) error { return nil } -func FieldsToEditSurvey(options *EditableOptions) error { +func FieldsToEditSurvey(editable *Editable) error { contains := func(s []string, str string) bool { for _, v := range s { if v == str { @@ -123,7 +202,7 @@ func FieldsToEditSurvey(options *EditableOptions) error { results := []string{} opts := []string{"Title", "Body"} - if options.ReviewersAllowed { + if editable.ReviewersAllowed { opts = append(opts, "Reviewers") } opts = append(opts, "Assignees", "Labels", "Projects", "Milestone") @@ -137,37 +216,37 @@ func FieldsToEditSurvey(options *EditableOptions) error { } if contains(results, "Title") { - options.TitleEdited = true + editable.TitleEdited = true } if contains(results, "Body") { - options.BodyEdited = true + editable.BodyEdited = true } if contains(results, "Reviewers") { - options.ReviewersEdited = true + editable.ReviewersEdited = true } if contains(results, "Assignees") { - options.AssigneesEdited = true + editable.AssigneesEdited = true } if contains(results, "Labels") { - options.LabelsEdited = true + editable.LabelsEdited = true } if contains(results, "Projects") { - options.ProjectsEdited = true + editable.ProjectsEdited = true } if contains(results, "Milestone") { - options.MilestoneEdited = true + editable.MilestoneEdited = true } return nil } -func FetchOptions(client *api.Client, repo ghrepo.Interface, options *EditableOptions) error { +func FetchOptions(client *api.Client, repo ghrepo.Interface, editable *Editable) error { input := api.RepoMetadataInput{ - Reviewers: options.ReviewersEdited, - Assignees: options.AssigneesEdited, - Labels: options.LabelsEdited, - Projects: options.ProjectsEdited, - Milestones: options.MilestoneEdited, + Reviewers: editable.ReviewersEdited, + Assignees: editable.AssigneesEdited, + Labels: editable.LabelsEdited, + Projects: editable.ProjectsEdited, + Milestones: editable.MilestoneEdited, } metadata, err := api.RepoMetadata(client, repo, input) if err != nil { @@ -195,12 +274,12 @@ func FetchOptions(client *api.Client, repo ghrepo.Interface, options *EditableOp milestones = append(milestones, m.Title) } - options.Metadata = *metadata - options.ReviewersOptions = append(users, teams...) - options.AssigneesOptions = users - options.LabelsOptions = labels - options.ProjectsOptions = projects - options.MilestoneOptions = milestones + editable.Metadata = *metadata + editable.ReviewersOptions = append(users, teams...) + editable.AssigneesOptions = users + editable.LabelsOptions = labels + editable.ProjectsOptions = projects + editable.MilestoneOptions = milestones return nil } @@ -231,8 +310,26 @@ func bodySurvey(body, editorCommand string) (string, error) { return result, err } -func assigneesSurvey(assignees api.Assignees, assigneesOpts []string) ([]string, error) { - if len(assigneesOpts) == 0 { +func reviewersSurvey(reviewers api.ReviewRequests, opts []string) ([]string, error) { + if len(opts) == 0 { + return nil, nil + } + logins := []string{} + for _, a := range reviewers.Nodes { + logins = append(logins, a.RequestedReviewer.Login) + } + var results []string + q := &survey.MultiSelect{ + Message: "Reviewers", + Options: opts, + Default: logins, + } + err := survey.AskOne(q, &results) + return results, err +} + +func assigneesSurvey(assignees api.Assignees, opts []string) ([]string, error) { + if len(opts) == 0 { return nil, nil } logins := []string{} @@ -242,15 +339,15 @@ func assigneesSurvey(assignees api.Assignees, assigneesOpts []string) ([]string, var results []string q := &survey.MultiSelect{ Message: "Assignees", - Options: assigneesOpts, + Options: opts, Default: logins, } err := survey.AskOne(q, &results) return results, err } -func labelsSurvey(labels api.Labels, labelOpts []string) ([]string, error) { - if len(labelOpts) == 0 { +func labelsSurvey(labels api.Labels, opts []string) ([]string, error) { + if len(opts) == 0 { return nil, nil } names := []string{} @@ -260,15 +357,15 @@ func labelsSurvey(labels api.Labels, labelOpts []string) ([]string, error) { var results []string q := &survey.MultiSelect{ Message: "Labels", - Options: labelOpts, + Options: opts, Default: names, } err := survey.AskOne(q, &results) return results, err } -func projectsSurvey(projectCards api.ProjectCards, projectsOpts []string) ([]string, error) { - if len(projectsOpts) == 0 { +func projectsSurvey(projectCards api.ProjectCards, opts []string) ([]string, error) { + if len(opts) == 0 { return nil, nil } names := []string{} @@ -278,21 +375,21 @@ func projectsSurvey(projectCards api.ProjectCards, projectsOpts []string) ([]str var results []string q := &survey.MultiSelect{ Message: "Projects", - Options: projectsOpts, + Options: opts, Default: names, } err := survey.AskOne(q, &results) return results, err } -func milestoneSurvey(milestone api.Milestone, milestoneOpts []string) (string, error) { - if len(milestoneOpts) == 0 { +func milestoneSurvey(milestone api.Milestone, opts []string) (string, error) { + if len(opts) == 0 { return "", nil } var result string q := &survey.Select{ Message: "Milestone", - Options: milestoneOpts, + Options: opts, Default: milestone.Title, } err := survey.AskOne(q, &result) @@ -308,3 +405,24 @@ func confirmSurvey() (bool, error) { err := survey.AskOne(q, &result) return result, err } + +func toParams(s []string, mapper func([]string) ([]string, error)) (*[]githubv4.ID, error) { + ids, err := mapper(s) + if err != nil { + return nil, err + } + gIds := make([]githubv4.ID, len(ids)) + for i, v := range ids { + gIds[i] = v + } + return &gIds, nil +} + +func toParam(s string, mapper func(string) (string, error)) (*githubv4.ID, error) { + id, err := mapper(s) + if err != nil { + return nil, err + } + gId := githubv4.ID(id) + return &gId, nil +}