diff --git a/api/queries_repo.go b/api/queries_repo.go index 54f88a7e1..b60dfaa92 100644 --- a/api/queries_repo.go +++ b/api/queries_repo.go @@ -464,6 +464,28 @@ func (m *RepoMetadataResult) MilestoneToID(title string) (string, error) { return "", errors.New("not found") } +func (m *RepoMetadataResult) Merge(m2 *RepoMetadataResult) { + if len(m2.AssignableUsers) > 0 || len(m.AssignableUsers) == 0 { + m.AssignableUsers = m2.AssignableUsers + } + + if len(m2.Teams) > 0 || len(m.Teams) == 0 { + m.Teams = m2.Teams + } + + if len(m2.Labels) > 0 || len(m.Labels) == 0 { + m.Labels = m2.Labels + } + + if len(m2.Projects) > 0 || len(m.Projects) == 0 { + m.Projects = m2.Projects + } + + if len(m2.Milestones) > 0 || len(m.Milestones) == 0 { + m.Milestones = m2.Milestones + } +} + type RepoMetadataInput struct { Assignees bool Reviewers bool diff --git a/pkg/cmd/issue/create/create.go b/pkg/cmd/issue/create/create.go index 8df05d723..f8ee73abb 100644 --- a/pkg/cmd/issue/create/create.go +++ b/pkg/cmd/issue/create/create.go @@ -206,7 +206,13 @@ func createRun(opts *CreateOptions) (err error) { } if action == prShared.MetadataAction { - err = prShared.MetadataSurvey(opts.IO, apiClient, baseRepo, &tb) + fetcher := &prShared.MetadataFetcher{ + IO: opts.IO, + APIClient: apiClient, + Repo: baseRepo, + State: &tb, + } + err = prShared.MetadataSurvey(opts.IO, baseRepo, fetcher, &tb) if err != nil { return } diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index 216f79d3e..db623acff 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -259,7 +259,13 @@ func createRun(opts *CreateOptions) (err error) { } if action == shared.MetadataAction { - err = shared.MetadataSurvey(opts.IO, client, ctx.BaseRepo, state) + fetcher := &shared.MetadataFetcher{ + IO: opts.IO, + APIClient: client, + Repo: ctx.BaseRepo, + State: state, + } + err = shared.MetadataSurvey(opts.IO, ctx.BaseRepo, fetcher, state) if err != nil { return } @@ -380,7 +386,7 @@ func NewIssueState(ctx CreateContext, opts CreateOptions) (*shared.IssueMetadata if opts.Autofill || !opts.TitleProvided || !opts.BodyProvided { err := initDefaultTitleBody(ctx, state) - if err != nil { + if err != nil && opts.Autofill { return nil, fmt.Errorf("could not compute title or body defaults: %w", err) } } diff --git a/pkg/cmd/pr/shared/params.go b/pkg/cmd/pr/shared/params.go index 9edb9e9e7..6efdf9494 100644 --- a/pkg/cmd/pr/shared/params.go +++ b/pkg/cmd/pr/shared/params.go @@ -37,25 +37,53 @@ func WithPrAndIssueQueryParams(baseURL string, state IssueMetadataState) (string return u.String(), nil } +// Ensure that tb.MetadataResult object exists and contains enough pre-fetched API data to be able +// to resolve all object listed in tb to GraphQL IDs. +func fillMetadata(client *api.Client, baseRepo ghrepo.Interface, tb *IssueMetadataState) error { + resolveInput := api.RepoResolveInput{} + + if len(tb.Assignees) > 0 && (tb.MetadataResult == nil || len(tb.MetadataResult.AssignableUsers) == 0) { + resolveInput.Assignees = tb.Assignees + } + + if len(tb.Reviewers) > 0 && (tb.MetadataResult == nil || len(tb.MetadataResult.AssignableUsers) == 0) { + resolveInput.Reviewers = tb.Reviewers + } + + if len(tb.Labels) > 0 && (tb.MetadataResult == nil || len(tb.MetadataResult.Labels) == 0) { + resolveInput.Labels = tb.Labels + } + + if len(tb.Projects) > 0 && (tb.MetadataResult == nil || len(tb.MetadataResult.Projects) == 0) { + resolveInput.Projects = tb.Projects + } + + if len(tb.Milestones) > 0 && (tb.MetadataResult == nil || len(tb.MetadataResult.Milestones) == 0) { + resolveInput.Milestones = tb.Milestones + } + + metadataResult, err := api.RepoResolveMetadataIDs(client, baseRepo, resolveInput) + if err != nil { + return err + } + + if tb.MetadataResult == nil { + tb.MetadataResult = metadataResult + } else { + tb.MetadataResult.Merge(metadataResult) + } + + return nil +} + func AddMetadataToIssueParams(client *api.Client, baseRepo ghrepo.Interface, params map[string]interface{}, tb *IssueMetadataState) error { if !tb.HasMetadata() { return nil } - if tb.MetadataResult == nil { - resolveInput := api.RepoResolveInput{ - Reviewers: tb.Reviewers, - Assignees: tb.Assignees, - Labels: tb.Labels, - Projects: tb.Projects, - Milestones: tb.Milestones, - } - - var err error - tb.MetadataResult, err = api.RepoResolveMetadataIDs(client, baseRepo, resolveInput) - if err != nil { - return err - } + err := fillMetadata(client, baseRepo, tb) + if err != nil { + return err } assigneeIDs, err := tb.MetadataResult.MembersToIDs(tb.Assignees) diff --git a/pkg/cmd/pr/shared/survey.go b/pkg/cmd/pr/shared/survey.go index 4e4a72855..8ef40d047 100644 --- a/pkg/cmd/pr/shared/survey.go +++ b/pkg/cmd/pr/shared/survey.go @@ -12,7 +12,6 @@ import ( "github.com/cli/cli/pkg/iostreams" "github.com/cli/cli/pkg/prompt" "github.com/cli/cli/pkg/surveyext" - "github.com/cli/cli/utils" ) type Action int @@ -196,7 +195,26 @@ func TitleSurvey(state *IssueMetadataState) error { return nil } -func MetadataSurvey(io *iostreams.IOStreams, client *api.Client, baseRepo ghrepo.Interface, state *IssueMetadataState) error { +type MetadataFetcher struct { + IO *iostreams.IOStreams + APIClient *api.Client + Repo ghrepo.Interface + State *IssueMetadataState +} + +func (mf *MetadataFetcher) RepoMetadataFetch(input api.RepoMetadataInput) (*api.RepoMetadataResult, error) { + mf.IO.StartProgressIndicator() + metadataResult, err := api.RepoMetadata(mf.APIClient, mf.Repo, input) + mf.IO.StopProgressIndicator() + mf.State.MetadataResult = metadataResult + return metadataResult, err +} + +type RepoMetadataFetcher interface { + RepoMetadataFetch(api.RepoMetadataInput) (*api.RepoMetadataResult, error) +} + +func MetadataSurvey(io *iostreams.IOStreams, baseRepo ghrepo.Interface, fetcher RepoMetadataFetcher, state *IssueMetadataState) error { isChosen := func(m string) bool { for _, c := range state.Metadata { if m == c { @@ -234,42 +252,32 @@ func MetadataSurvey(io *iostreams.IOStreams, client *api.Client, baseRepo ghrepo Projects: isChosen("Projects"), Milestones: isChosen("Milestone"), } - s := utils.Spinner(io.ErrOut) - utils.StartSpinner(s) - state.MetadataResult, err = api.RepoMetadata(client, baseRepo, metadataInput) - utils.StopSpinner(s) + metadataResult, err := fetcher.RepoMetadataFetch(metadataInput) if err != nil { return fmt.Errorf("error fetching metadata options: %w", err) } var users []string - for _, u := range state.MetadataResult.AssignableUsers { + for _, u := range metadataResult.AssignableUsers { users = append(users, u.Login) } var teams []string - for _, t := range state.MetadataResult.Teams { + for _, t := range metadataResult.Teams { teams = append(teams, fmt.Sprintf("%s/%s", baseRepo.RepoOwner(), t.Slug)) } var labels []string - for _, l := range state.MetadataResult.Labels { + for _, l := range metadataResult.Labels { labels = append(labels, l.Name) } var projects []string - for _, l := range state.MetadataResult.Projects { + for _, l := range metadataResult.Projects { projects = append(projects, l.Name) } milestones := []string{noMilestone} - for _, m := range state.MetadataResult.Milestones { + for _, m := range metadataResult.Milestones { milestones = append(milestones, m.Title) } - type metadataValues struct { - Reviewers []string - Assignees []string - Labels []string - Projects []string - Milestone string - } var mqs []*survey.Question if isChosen("Reviewers") { if len(users) > 0 || len(teams) > 0 { @@ -345,17 +353,38 @@ func MetadataSurvey(io *iostreams.IOStreams, client *api.Client, baseRepo ghrepo fmt.Fprintln(io.ErrOut, "warning: no milestones in the repository") } } - values := metadataValues{} + + values := struct { + Reviewers []string + Assignees []string + Labels []string + Projects []string + Milestone string + }{} + err = prompt.SurveyAsk(mqs, &values, survey.WithKeepFilter(true)) if err != nil { return fmt.Errorf("could not prompt: %w", err) } - state.Reviewers = values.Reviewers - state.Assignees = values.Assignees - state.Labels = values.Labels - state.Projects = values.Projects - if values.Milestone != "" && values.Milestone != noMilestone { - state.Milestones = []string{values.Milestone} + + if isChosen("Reviewers") { + state.Reviewers = values.Reviewers + } + if isChosen("Assignees") { + state.Assignees = values.Assignees + } + if isChosen("Labels") { + state.Labels = values.Labels + } + if isChosen("Projects") { + state.Projects = values.Projects + } + if isChosen("Milestone") { + if values.Milestone != "" && values.Milestone != noMilestone { + state.Milestones = []string{values.Milestone} + } else { + state.Milestones = []string{} + } } return nil diff --git a/pkg/cmd/pr/shared/survey_test.go b/pkg/cmd/pr/shared/survey_test.go new file mode 100644 index 000000000..a500040d3 --- /dev/null +++ b/pkg/cmd/pr/shared/survey_test.go @@ -0,0 +1,144 @@ +package shared + +import ( + "testing" + + "github.com/cli/cli/api" + "github.com/cli/cli/internal/ghrepo" + "github.com/cli/cli/pkg/iostreams" + "github.com/cli/cli/pkg/prompt" + "github.com/stretchr/testify/assert" +) + +type metadataFetcher struct { + metadataResult *api.RepoMetadataResult +} + +func (mf *metadataFetcher) RepoMetadataFetch(input api.RepoMetadataInput) (*api.RepoMetadataResult, error) { + return mf.metadataResult, nil +} + +func TestMetadataSurvey_selectAll(t *testing.T) { + io, _, stdout, stderr := iostreams.Test() + + repo := ghrepo.New("OWNER", "REPO") + + fetcher := &metadataFetcher{ + metadataResult: &api.RepoMetadataResult{ + AssignableUsers: []api.RepoAssignee{ + {Login: "hubot"}, + {Login: "monalisa"}, + }, + Labels: []api.RepoLabel{ + {Name: "help wanted"}, + {Name: "good first issue"}, + }, + Projects: []api.RepoProject{ + {Name: "Huge Refactoring"}, + {Name: "The road to 1.0"}, + }, + Milestones: []api.RepoMilestone{ + {Title: "1.2 patch release"}, + }, + }, + } + + as, restoreAsk := prompt.InitAskStubber() + defer restoreAsk() + + as.Stub([]*prompt.QuestionStub{ + { + Name: "metadata", + Value: []string{"Labels", "Projects", "Assignees", "Reviewers", "Milestone"}, + }, + }) + as.Stub([]*prompt.QuestionStub{ + { + Name: "reviewers", + Value: []string{"monalisa"}, + }, + { + Name: "assignees", + Value: []string{"hubot"}, + }, + { + Name: "labels", + Value: []string{"good first issue"}, + }, + { + Name: "projects", + Value: []string{"The road to 1.0"}, + }, + { + Name: "milestone", + Value: []string{"(none)"}, + }, + }) + + state := &IssueMetadataState{ + Assignees: []string{"hubot"}, + } + err := MetadataSurvey(io, repo, fetcher, state) + assert.NoError(t, err) + + assert.Equal(t, "", stdout.String()) + assert.Equal(t, "", stderr.String()) + + assert.Equal(t, []string{"hubot"}, state.Assignees) + assert.Equal(t, []string{"monalisa"}, state.Reviewers) + assert.Equal(t, []string{"good first issue"}, state.Labels) + assert.Equal(t, []string{"The road to 1.0"}, state.Projects) + assert.Equal(t, []string{}, state.Milestones) +} + +func TestMetadataSurvey_keepExisting(t *testing.T) { + io, _, stdout, stderr := iostreams.Test() + + repo := ghrepo.New("OWNER", "REPO") + + fetcher := &metadataFetcher{ + metadataResult: &api.RepoMetadataResult{ + Labels: []api.RepoLabel{ + {Name: "help wanted"}, + {Name: "good first issue"}, + }, + Projects: []api.RepoProject{ + {Name: "Huge Refactoring"}, + {Name: "The road to 1.0"}, + }, + }, + } + + as, restoreAsk := prompt.InitAskStubber() + defer restoreAsk() + + as.Stub([]*prompt.QuestionStub{ + { + Name: "metadata", + Value: []string{"Labels", "Projects"}, + }, + }) + as.Stub([]*prompt.QuestionStub{ + { + Name: "labels", + Value: []string{"good first issue"}, + }, + { + Name: "projects", + Value: []string{"The road to 1.0"}, + }, + }) + + state := &IssueMetadataState{ + Assignees: []string{"hubot"}, + } + err := MetadataSurvey(io, repo, fetcher, state) + assert.NoError(t, err) + + assert.Equal(t, "", stdout.String()) + assert.Equal(t, "", stderr.String()) + + assert.Equal(t, []string{"hubot"}, state.Assignees) + assert.Equal(t, []string{"good first issue"}, state.Labels) + assert.Equal(t, []string{"The road to 1.0"}, state.Projects) +} diff --git a/pkg/prompt/stubber.go b/pkg/prompt/stubber.go index a3302c3f9..be920cd25 100644 --- a/pkg/prompt/stubber.go +++ b/pkg/prompt/stubber.go @@ -51,6 +51,9 @@ func InitAskStubber() (*AskStubber, func()) { // actually set response stubbedQuestions := as.Stubs[count] + if len(stubbedQuestions) != len(qs) { + panic(fmt.Sprintf("asked questions: %d; stubbed questions: %d", len(qs), len(stubbedQuestions))) + } for i, sq := range stubbedQuestions { q := qs[i] if q.Name != sq.Name {