diff --git a/internal/prompter/accessible_prompter_test.go b/internal/prompter/accessible_prompter_test.go index 32a412646..b2a66141a 100644 --- a/internal/prompter/accessible_prompter_test.go +++ b/internal/prompter/accessible_prompter_test.go @@ -101,6 +101,27 @@ func TestAccessiblePrompter(t *testing.T) { assert.Equal(t, expectedIndex, selectValue) }) + t.Run("Select - invalid defaults are excluded from prompt", func(t *testing.T) { + console := newTestVirtualTerminal(t) + p := newTestAccessiblePrompter(t, console) + dummyDefaultValue := "foo" + options := []string{"1", "2"} + + go func() { + // Wait for prompt to appear without the invalid default value + _, err := console.ExpectString("Select a number \r\n") + require.NoError(t, err) + + // Select option 2 + _, err = console.SendLine("2") + require.NoError(t, err) + }() + + selectValue, err := p.Select("Select a number", dummyDefaultValue, options) + require.NoError(t, err) + assert.Equal(t, 1, selectValue) + }) + t.Run("MultiSelect", func(t *testing.T) { console := newTestVirtualTerminal(t) p := newTestAccessiblePrompter(t, console) @@ -178,6 +199,31 @@ func TestAccessiblePrompter(t *testing.T) { assert.Equal(t, expectedIndices, multiSelectValues) }) + t.Run("MultiSelect - invalid defaults are excluded from prompt", func(t *testing.T) { + console := newTestVirtualTerminal(t) + p := newTestAccessiblePrompter(t, console) + dummyDefaultValues := []string{"foo", "bar"} + options := []string{"1", "2"} + + go func() { + // Wait for prompt to appear without the invalid default values + _, err := console.ExpectString("Select a number \r\n") + require.NoError(t, err) + + // Not selecting anything will fail because there are no defaults. + _, err = console.SendLine("2") + require.NoError(t, err) + + // This confirms selections + _, err = console.SendLine("0") + require.NoError(t, err) + }() + + multiSelectValues, err := p.MultiSelect("Select a number", dummyDefaultValues, options) + require.NoError(t, err) + assert.Equal(t, []int{1}, multiSelectValues) + }) + t.Run("Input", func(t *testing.T) { console := newTestVirtualTerminal(t) p := newTestAccessiblePrompter(t, console) diff --git a/internal/prompter/prompter.go b/internal/prompter/prompter.go index f8bd04c35..c2233fd92 100644 --- a/internal/prompter/prompter.go +++ b/internal/prompter/prompter.go @@ -79,24 +79,32 @@ func (p *accessiblePrompter) newForm(groups ...*huh.Group) *huh.Form { // addDefaultsToPrompt adds default values to the prompt string. func (p *accessiblePrompter) addDefaultsToPrompt(prompt string, defaultValues []string) string { - // We don't show empty default values in the prompt. + // Removing empty defaults from the slice. defaultValues = slices.DeleteFunc(defaultValues, func(s string) bool { return s == "" }) + // Pluralizing the prompt if there are multiple default values. if len(defaultValues) == 1 { prompt = fmt.Sprintf("%s (default: %s)", prompt, defaultValues[0]) } else if len(defaultValues) > 1 { prompt = fmt.Sprintf("%s (defaults: %s)", prompt, strings.Join(defaultValues, ", ")) } + // Zero-length defaultValues means return prompt unchanged. return prompt } func (p *accessiblePrompter) Select(prompt, defaultValue string, options []string) (int, error) { var result int - formOptions := []huh.Option[int]{} + + // Remove invalid default values from the defaults slice. + if !slices.Contains(options, defaultValue) { + defaultValue = "" + } + prompt = p.addDefaultsToPrompt(prompt, []string{defaultValue}) + formOptions := []huh.Option[int]{} for i, o := range options { // If this option is the default value, assign its index // to the result variable. huh will treat it as a default selection. @@ -121,6 +129,12 @@ func (p *accessiblePrompter) Select(prompt, defaultValue string, options []strin func (p *accessiblePrompter) MultiSelect(prompt string, defaults []string, options []string) ([]int, error) { var result []int + + // Remove invalid default values from the defaults slice. + defaults = slices.DeleteFunc(defaults, func(s string) bool { + return !slices.Contains(options, s) + }) + prompt = p.addDefaultsToPrompt(prompt, defaults) formOptions := make([]huh.Option[int], len(options)) for i, o := range options { @@ -189,11 +203,13 @@ func (p *accessiblePrompter) Password(prompt string) (string, error) { func (p *accessiblePrompter) Confirm(prompt string, defaultValue bool) (bool, error) { result := defaultValue + if defaultValue { prompt = p.addDefaultsToPrompt(prompt, []string{"yes"}) } else { prompt = p.addDefaultsToPrompt(prompt, []string{"no"}) } + form := p.newForm( huh.NewGroup( huh.NewConfirm(). @@ -201,6 +217,7 @@ func (p *accessiblePrompter) Confirm(prompt string, defaultValue bool) (bool, er Value(&result), ), ) + if err := form.Run(); err != nil { return false, err }