From 3c443efbed4735ac331b3354ff164ae77b7d0495 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Wed, 12 Jan 2022 21:37:24 +0100 Subject: [PATCH] Improve Survey prompt stubber for tests Both SurveyAsk and SurveyAskOne methods now share the same sets of stubs, making it possible to change which of these methods is used in the implementation without breaking tests. A new method `AskStubber.StubPrompt("")` is added as test helper to supersede old Stub and StubOne methods. The new helper matches on prompt messages rather than on field names, enabling tests to be written based on what the user would see rather than coupling to implementation details. The new stubber also allows verifying whether a Select or MultiSelect was rendered with the expected set of options. Furthermore, if a stubbed value is not present among those options, the stubber will panic instead of continuing normally. Stubbed Selects with an int instead of a string target receiver are now transparently handled. The values for Select stubs are always strings in tests, but the stubber will write an int answer if the receiver expects one as a selected index instead of a selected string value. Lastly, this set of changes improves test resiliency since the stubs are now matched based on prompt message (or field name for legacy stubs created with Stub) instead of sequentially, enabling the implementation to reorder the prompts without breaking existing tests. --- pkg/cmd/pr/shared/survey_test.go | 2 +- pkg/prompt/stubber.go | 196 +++++++++++++++++++++---------- 2 files changed, 137 insertions(+), 61 deletions(-) diff --git a/pkg/cmd/pr/shared/survey_test.go b/pkg/cmd/pr/shared/survey_test.go index a28d96198..b0daa59f0 100644 --- a/pkg/cmd/pr/shared/survey_test.go +++ b/pkg/cmd/pr/shared/survey_test.go @@ -71,7 +71,7 @@ func TestMetadataSurvey_selectAll(t *testing.T) { }, { Name: "milestone", - Value: []string{"(none)"}, + Value: "(none)", }, }) diff --git a/pkg/prompt/stubber.go b/pkg/prompt/stubber.go index be920cd25..6ad992d02 100644 --- a/pkg/prompt/stubber.go +++ b/pkg/prompt/stubber.go @@ -2,19 +2,15 @@ package prompt import ( "fmt" - "reflect" + "strings" "github.com/AlecAivazis/survey/v2" "github.com/AlecAivazis/survey/v2/core" + "github.com/cli/cli/v2/pkg/surveyext" ) type AskStubber struct { - Asks [][]*survey.Question - AskOnes []*survey.Prompt - Count int - OneCount int - Stubs [][]*QuestionStub - StubOnes []*PromptStub + stubs []*QuestionStub } func InitAskStubber() (*AskStubber, func()) { @@ -22,53 +18,104 @@ func InitAskStubber() (*AskStubber, func()) { origSurveyAskOne := SurveyAskOne as := AskStubber{} - SurveyAskOne = func(p survey.Prompt, response interface{}, opts ...survey.AskOpt) error { - as.AskOnes = append(as.AskOnes, &p) - count := as.OneCount - as.OneCount += 1 - if count >= len(as.StubOnes) { - panic(fmt.Sprintf("more asks than stubs. most recent call: %v", p)) - } - stubbedPrompt := as.StubOnes[count] - if stubbedPrompt.Default { - // TODO this is failing for basic AskOne invocations with a string result. - defaultValue := reflect.ValueOf(p).Elem().FieldByName("Default") - _ = core.WriteAnswer(response, "", defaultValue) - } else { - _ = core.WriteAnswer(response, "", stubbedPrompt.Value) + answerFromStub := func(p survey.Prompt, fieldName string, response interface{}) error { + var message string + var defaultValue interface{} + var options []string + switch pt := p.(type) { + case *survey.Confirm: + message = pt.Message + defaultValue = pt.Default + case *survey.Input: + message = pt.Message + defaultValue = pt.Default + case *survey.Select: + message = pt.Message + options = pt.Options + case *survey.MultiSelect: + message = pt.Message + options = pt.Options + case *survey.Password: + message = pt.Message + case *surveyext.GhEditor: + message = pt.Message + defaultValue = pt.Default + default: + panic(fmt.Sprintf("prompt type %T is not supported by the stubber", pt)) } + var stub *QuestionStub + for _, s := range as.stubs { + if !s.matched && (s.message == "" && strings.EqualFold(s.Name, fieldName) || s.message == message) { + stub = s + stub.matched = true + break + } + } + if stub == nil { + panic(fmt.Sprintf("no prompt stub for %q", message)) + } + + if len(stub.options) > 0 { + if err := compareOptions(stub.options, options); err != nil { + panic(fmt.Sprintf("options mismatch for %q: %v", message, err)) + } + } + + userValue := stub.Value + + if stringValue, ok := stub.Value.(string); ok && len(options) > 0 { + foundIndex := -1 + for i, o := range options { + if o == stringValue { + foundIndex = i + break + } + } + if foundIndex < 0 { + panic(fmt.Sprintf("answer %q not found in options for %q: %v", stringValue, message, options)) + } + userValue = core.OptionAnswer{ + Value: stringValue, + Index: foundIndex, + } + } + + if stub.Default { + if defaultIndex, ok := defaultValue.(int); ok && len(options) > 0 { + userValue = core.OptionAnswer{ + Value: options[defaultIndex], + Index: defaultIndex, + } + } else if defaultValue == nil && len(options) > 0 { + userValue = core.OptionAnswer{ + Value: options[0], + Index: 0, + } + } else { + userValue = defaultValue + } + } + + if err := core.WriteAnswer(response, fieldName, userValue); err != nil { + return fmt.Errorf("AskStubber failed writing the answer for field %q: %w", fieldName, err) + } return nil } + SurveyAskOne = func(p survey.Prompt, response interface{}, opts ...survey.AskOpt) error { + return answerFromStub(p, "", response) + } + SurveyAsk = func(qs []*survey.Question, response interface{}, opts ...survey.AskOpt) error { - as.Asks = append(as.Asks, qs) - count := as.Count - as.Count += 1 - if count >= len(as.Stubs) { - panic(fmt.Sprintf("more asks than stubs. most recent call: %#v", qs)) - } - - // actually set response - stubbedQuestions := as.Stubs[count] - if len(stubbedQuestions) != len(qs) { - panic(fmt.Sprintf("asked questions: %d; stubbed questions: %d", len(qs), len(stubbedQuestions))) - } - for i, sq := range stubbedQuestions { - q := qs[i] - if q.Name != sq.Name { - panic(fmt.Sprintf("stubbed question mismatch: %s != %s", q.Name, sq.Name)) - } - if sq.Default { - defaultValue := reflect.ValueOf(q.Prompt).Elem().FieldByName("Default") - _ = core.WriteAnswer(response, q.Name, defaultValue) - } else { - _ = core.WriteAnswer(response, q.Name, sq.Value) + for _, q := range qs { + if err := answerFromStub(q.Prompt, q.Name, response); err != nil { + return err } } - return nil } + teardown := func() { SurveyAsk = origSurveyAsk SurveyAskOne = origSurveyAskOne @@ -76,30 +123,59 @@ func InitAskStubber() (*AskStubber, func()) { return &as, teardown } -type PromptStub struct { - Value interface{} - Default bool -} - type QuestionStub struct { Name string Value interface{} Default bool + + matched bool + message string + options []string } +// AssertOptions asserts the options presented to the user in Selects and MultiSelects. +func (s *QuestionStub) AssertOptions(opts []string) *QuestionStub { + s.options = opts + return s +} + +// AnswerWith defines an answer for the given stub. +func (s *QuestionStub) AnswerWith(v interface{}) *QuestionStub { + s.Value = v + return s +} + +// AnswerDefault marks the current stub to be answered with the default value for the prompt question. +func (s *QuestionStub) AnswerDefault() *QuestionStub { + s.Default = true + return s +} + +// Deprecated: use StubPrompt func (as *AskStubber) StubOne(value interface{}) { - as.StubOnes = append(as.StubOnes, &PromptStub{ - Value: value, - }) -} - -func (as *AskStubber) StubOneDefault() { - as.StubOnes = append(as.StubOnes, &PromptStub{ - Default: true, - }) + as.Stub([]*QuestionStub{{Value: value}}) } +// Deprecated: use StubPrompt func (as *AskStubber) Stub(stubbedQuestions []*QuestionStub) { - // A call to .Ask takes a list of questions; a stub is then a list of questions in the same order. - as.Stubs = append(as.Stubs, stubbedQuestions) + as.stubs = append(as.stubs, stubbedQuestions...) +} + +// StubPrompt records a stub for an interactive prompt matched by its message. +func (as *AskStubber) StubPrompt(msg string) *QuestionStub { + stub := &QuestionStub{message: msg} + as.stubs = append(as.stubs, stub) + return stub +} + +func compareOptions(expected, got []string) error { + if len(expected) != len(got) { + return fmt.Errorf("expected %v, got %v (length mismatch)", expected, got) + } + for i, v := range expected { + if v != got[i] { + return fmt.Errorf("expected %v, got %v", expected, got) + } + } + return nil }