From cfb2224176d879fb60bf0c0f92bdafaf35bb2c60 Mon Sep 17 00:00:00 2001 From: Kynan Ware <47394200+BagToad@users.noreply.github.com> Date: Thu, 19 Mar 2026 15:36:13 -0600 Subject: [PATCH] refactor(huh prompter): custom Field for MultiSelectWithSearch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the OptionsFunc-based MultiSelectWithSearch with a custom huh Field implementation. huh's OptionsFunc runs in a goroutine, causing data races with selection state and stale cache issues that made selections disappear on toggle or search changes. The custom field (multiSelectSearchField) combines a text input and multi-select list in a single field with full control over the update loop. Search runs asynchronously via tea.Cmd when the user presses Enter, with a themed spinner during loading. Selections are stored in a simple map — no goroutine races, no Eval cache, no syncAccessor. Also adds defensive validation for mismatched Keys/Labels slices from searchFunc. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- internal/prompter/huh_prompter.go | 156 +----- internal/prompter/huh_prompter_test.go | 47 +- internal/prompter/multi_select_with_search.go | 450 ++++++++++++++++++ 3 files changed, 498 insertions(+), 155 deletions(-) create mode 100644 internal/prompter/multi_select_with_search.go diff --git a/internal/prompter/huh_prompter.go b/internal/prompter/huh_prompter.go index 95bf43fab..40a27d507 100644 --- a/internal/prompter/huh_prompter.go +++ b/internal/prompter/huh_prompter.go @@ -3,7 +3,6 @@ package prompter import ( "fmt" "slices" - "sync" "charm.land/huh/v2" "github.com/cli/cli/v2/internal/ghinstance" @@ -93,162 +92,19 @@ func (p *huhPrompter) MultiSelect(prompt string, defaults []string, options []st return *result, nil } -// searchOptionsBinding is used as the OptionsFunc binding for MultiSelectWithSearch. -// By including both the search query and selected values, the binding hash changes -// whenever either changes. This prevents huh's internal Eval cache from serving -// stale option sets that would overwrite the user's current selections. -type searchOptionsBinding struct { - Query *string - Selected *[]string -} - -// syncAccessor is a thread-safe huh.Accessor implementation. -// huh calls OptionsFunc from a goroutine while the main event loop -// writes field values via Set(). This accessor synchronizes both -// paths through the same mutex. -type syncAccessor[T any] struct { - mu *sync.Mutex - value T -} - -func (a *syncAccessor[T]) Get() T { - a.mu.Lock() - defer a.mu.Unlock() - return a.value -} - -func (a *syncAccessor[T]) Set(value T) { - a.mu.Lock() - defer a.mu.Unlock() - a.value = value -} - -func (p *huhPrompter) buildMultiSelectWithSearchForm(prompt, searchPrompt string, defaultValues, persistentValues []string, searchFunc func(string) MultiSelectSearchResult) (*huh.Form, *syncAccessor[[]string]) { - var mu sync.Mutex - - queryAccessor := &syncAccessor[string]{mu: &mu} - selectAccessor := &syncAccessor[[]string]{mu: &mu, value: slices.Clone(defaultValues)} - - optionKeyLabels := make(map[string]string) - for _, k := range defaultValues { - optionKeyLabels[k] = k - } - - // Cache searchFunc results locally keyed by query string. - // This avoids redundant calls when the OptionsFunc binding hash changes - // due to selection changes (not query changes). - searchCacheValid := false - var cachedSearchQuery string - var cachedSearchResult MultiSelectSearchResult - - buildOptions := func() []huh.Option[string] { - mu.Lock() - query := queryAccessor.value - needsFetch := !searchCacheValid || query != cachedSearchQuery - mu.Unlock() - - if needsFetch { - result := searchFunc(query) - mu.Lock() - cachedSearchResult = result - cachedSearchQuery = query - searchCacheValid = true - mu.Unlock() - } - - mu.Lock() - defer mu.Unlock() - - selectedValues := selectAccessor.value - result := cachedSearchResult - - if result.Err != nil { - return nil - } - for i, k := range result.Keys { - optionKeyLabels[k] = result.Labels[i] - } - - var formOptions []huh.Option[string] - seen := make(map[string]bool) - - // 1. Currently selected values (persisted across searches). - for _, k := range selectedValues { - if seen[k] { - continue - } - seen[k] = true - l := optionKeyLabels[k] - if l == "" { - l = k - } - formOptions = append(formOptions, huh.NewOption(l, k).Selected(true)) - } - - // 2. Search results. - for i, k := range result.Keys { - if seen[k] { - continue - } - seen[k] = true - l := result.Labels[i] - if l == "" { - l = k - } - formOptions = append(formOptions, huh.NewOption(l, k)) - } - - // 3. Persistent options. - for _, k := range persistentValues { - if seen[k] { - continue - } - seen[k] = true - l := optionKeyLabels[k] - if l == "" { - l = k - } - formOptions = append(formOptions, huh.NewOption(l, k)) - } - - if len(formOptions) == 0 { - formOptions = append(formOptions, huh.NewOption("No results", "")) - } - - return formOptions - } - - binding := &searchOptionsBinding{ - Query: &queryAccessor.value, - Selected: &selectAccessor.value, - } - - form := p.newForm( - huh.NewGroup( - huh.NewInput(). - Title(searchPrompt). - Placeholder("Type to search, Ctrl+U to clear"). - Accessor(queryAccessor), - huh.NewMultiSelect[string](). - Title(prompt). - Options(buildOptions()...). - OptionsFunc(func() []huh.Option[string] { - return buildOptions() - }, binding). - Accessor(selectAccessor). - Limit(0), - ), - ) - return form, selectAccessor +func (p *huhPrompter) buildMultiSelectWithSearchForm(prompt, searchPrompt string, defaultValues, persistentValues []string, searchFunc func(string) MultiSelectSearchResult) (*huh.Form, *multiSelectSearchField) { + field := newMultiSelectSearchField(prompt, searchPrompt, defaultValues, persistentValues, searchFunc) + form := p.newForm(huh.NewGroup(field)) + return form, field } func (p *huhPrompter) MultiSelectWithSearch(prompt, searchPrompt string, defaultValues, persistentValues []string, searchFunc func(string) MultiSelectSearchResult) ([]string, error) { - form, accessor := p.buildMultiSelectWithSearchForm(prompt, searchPrompt, defaultValues, persistentValues, searchFunc) + form, field := p.buildMultiSelectWithSearchForm(prompt, searchPrompt, defaultValues, persistentValues, searchFunc) err := form.Run() if err != nil { return nil, err } - return accessor.Get(), nil + return field.selectedKeys(), nil } func (p *huhPrompter) buildInputForm(prompt, defaultValue string) (*huh.Form, *string) { diff --git a/internal/prompter/huh_prompter_test.go b/internal/prompter/huh_prompter_test.go index e039038ad..404867d23 100644 --- a/internal/prompter/huh_prompter_test.go +++ b/internal/prompter/huh_prompter_test.go @@ -450,7 +450,7 @@ func TestHuhPrompterMultiSelectWithSearch(t *testing.T) { "Select", "Search", tt.defaults, tt.persistent, staticSearchFunc, ) runForm(t, f, tt.ix) - assert.Equal(t, tt.wantResult, result.Get()) + assert.Equal(t, tt.wantResult, result.selectedKeys()) }) } } @@ -482,7 +482,7 @@ func TestHuhPrompterMultiSelectWithSearchPersistence(t *testing.T) { tab(), waitForOptions(), enter(), // submit — result-a should persist )) - assert.Equal(t, []string{"result-a"}, result.Get()) + assert.Equal(t, []string{"result-a"}, result.selectedKeys()) }) t.Run("empty search results shows no-results placeholder", func(t *testing.T) { emptySearchFunc := func(query string) MultiSelectSearchResult { @@ -492,10 +492,10 @@ func TestHuhPrompterMultiSelectWithSearchPersistence(t *testing.T) { f, result := p.buildMultiSelectWithSearchForm( "Select", "Search", nil, nil, emptySearchFunc, ) - // With no results, the "No results" placeholder is shown but nothing - // is selected, so submitting returns empty. + // With no results, the "No results" message is shown. + // Toggle does nothing, submitting returns empty. runForm(t, f, newInteraction(tab(), waitForOptions(), toggle(), enter())) - assert.Equal(t, []string{""}, result.Get()) + assert.Equal(t, []string{}, result.selectedKeys()) }) } @@ -581,3 +581,40 @@ func TestHuhPrompterInputHostname(t *testing.T) { }) } } + +func TestHuhPrompterMultiSelectWithSearchBackspace(t *testing.T) { + // Simulate real API latency and non-overlapping results. + staticSearchFunc := func(query string) MultiSelectSearchResult { + time.Sleep(100 * time.Millisecond) // simulate API latency + if query == "" { + return MultiSelectSearchResult{ + Keys: []string{"alice", "bob"}, + Labels: []string{"Alice", "Bob"}, + } + } + return MultiSelectSearchResult{ + Keys: []string{"frank", "fiona"}, + Labels: []string{"Frank", "Fiona"}, + } + } + + t.Run("selections persist after backspacing search query", func(t *testing.T) { + p := newTestHuhPrompter() + f, result := p.buildMultiSelectWithSearchForm( + "Select", "Search", nil, nil, staticSearchFunc, + ) + longWait := interactionStep{delay: 300 * time.Millisecond} + runForm(t, f, newInteraction( + tab(), longWait, + toggle(), // toggle alice + shiftTab(), // back to search input + typeKeys("f"), // type "f" + longWait, // wait for API + OptionsFunc + typeKeys("\x7f"), // backspace to "" + longWait, // wait for cache/API + tab(), longWait, + enter(), + )) + assert.Equal(t, []string{"alice"}, result.selectedKeys()) + }) +} diff --git a/internal/prompter/multi_select_with_search.go b/internal/prompter/multi_select_with_search.go new file mode 100644 index 000000000..cc1302fbd --- /dev/null +++ b/internal/prompter/multi_select_with_search.go @@ -0,0 +1,450 @@ +package prompter + +import ( + "fmt" + "io" + "strings" + "time" + + "charm.land/bubbles/v2/key" + "charm.land/bubbles/v2/spinner" + "charm.land/bubbles/v2/textinput" + "charm.land/bubbles/v2/viewport" + tea "charm.land/bubbletea/v2" + "charm.land/huh/v2" + "charm.land/lipgloss/v2" +) + +// multiSelectSearchField is a custom huh Field that combines a text input +// for searching with a multi-select list. Unlike huh's built-in OptionsFunc, +// search results are loaded synchronously when the user presses Enter in +// the search input, avoiding goroutine races with selection state. +type multiSelectSearchField struct { + // configuration + title string + searchTitle string + placeholder string + searchFunc func(string) MultiSelectSearchResult + + // state + mode msMode // which sub-component has focus + search textinput.Model + cursor int + viewport viewport.Model + loading bool + loadingStart time.Time + spinner spinner.Model + + // options and selections + options []msOption + selected map[string]bool // key → selected (source of truth) + optionLabels map[string]string // key → display label + lastQuery string + defaultValues []string + persistent []string + + // field metadata + key string + err error + focused bool + width int + height int + theme huh.Theme + hasDarkBg bool + position huh.FieldPosition +} + +type msMode int + +const ( + msModeSearch msMode = iota + msModeSelect +) + +type msOption struct { + label string + value string +} + +// msSearchResultMsg carries search results back from the background goroutine. +type msSearchResultMsg struct { + query string + result MultiSelectSearchResult +} + +func newMultiSelectSearchField( + title, searchTitle string, + defaults, persistent []string, + searchFunc func(string) MultiSelectSearchResult, +) *multiSelectSearchField { + ti := textinput.New() + ti.Prompt = "> " + ti.Placeholder = "Type to search" + ti.Focus() + + selected := make(map[string]bool) + for _, k := range defaults { + selected[k] = true + } + + m := &multiSelectSearchField{ + title: title, + searchTitle: searchTitle, + searchFunc: searchFunc, + mode: msModeSearch, + search: ti, + selected: selected, + optionLabels: make(map[string]string), + defaultValues: defaults, + persistent: persistent, + height: 10, + spinner: spinner.New(spinner.WithSpinner(spinner.Line)), + } + + // Load initial results synchronously (form hasn't started yet). + m.applySearchResult("", m.searchFunc("")) + + return m +} + +// startSearch launches an async search and returns a tea.Cmd that will +// deliver the result via msSearchResultMsg. +func (m *multiSelectSearchField) startSearch(query string) tea.Cmd { + m.loading = true + m.loadingStart = time.Now() + searchFunc := m.searchFunc + return tea.Batch( + func() tea.Msg { + return msSearchResultMsg{query: query, result: searchFunc(query)} + }, + m.spinner.Tick, + ) +} + +// applySearchResult processes a completed search and rebuilds the option list. +func (m *multiSelectSearchField) applySearchResult(query string, result MultiSelectSearchResult) { + m.loading = false + m.lastQuery = query + if result.Err != nil { + m.err = result.Err + return + } + if len(result.Keys) != len(result.Labels) { + m.err = fmt.Errorf("search returned mismatched keys and labels: %d keys, %d labels", len(result.Keys), len(result.Labels)) + return + } + + for i, k := range result.Keys { + m.optionLabels[k] = result.Labels[i] + } + + // Build option list: selected items first, then results, then persistent. + var options []msOption + seen := make(map[string]bool) + + // 1. Currently selected items. + for _, k := range m.selectedKeys() { + if seen[k] { + continue + } + seen[k] = true + options = append(options, msOption{label: m.label(k), value: k}) + } + + // 2. Search results. + for i, k := range result.Keys { + if seen[k] { + continue + } + seen[k] = true + l := result.Labels[i] + if l == "" { + l = k + } + options = append(options, msOption{label: l, value: k}) + } + + // 3. Persistent options. + for _, k := range m.persistent { + if seen[k] { + continue + } + seen[k] = true + options = append(options, msOption{label: m.label(k), value: k}) + } + + m.options = options + m.cursor = 0 + m.err = nil +} + +func (m *multiSelectSearchField) selectedKeys() []string { + keys := make([]string, 0) + // Maintain order: defaults first, then any added during this session. + seen := make(map[string]bool) + for _, k := range m.defaultValues { + if m.selected[k] && !seen[k] { + keys = append(keys, k) + seen[k] = true + } + } + for _, o := range m.options { + if m.selected[o.value] && !seen[o.value] { + keys = append(keys, o.value) + seen[o.value] = true + } + } + return keys +} + +func (m *multiSelectSearchField) label(key string) string { + if l, ok := m.optionLabels[key]; ok && l != "" { + return l + } + return key +} + +// --- huh.Field interface --- + +func (m *multiSelectSearchField) Init() tea.Cmd { + return nil +} + +func (m *multiSelectSearchField) Update(msg tea.Msg) (huh.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.BackgroundColorMsg: + m.hasDarkBg = msg.IsDark() + + case msSearchResultMsg: + m.applySearchResult(msg.query, msg.result) + m.mode = msModeSelect + m.search.Blur() + return m, nil + + case spinner.TickMsg: + if !m.loading { + break + } + var cmd tea.Cmd + m.spinner, cmd = m.spinner.Update(msg) + return m, cmd + + case tea.KeyPressMsg: + if m.loading { + return m, nil // ignore keys while loading + } + switch m.mode { + case msModeSearch: + return m.updateSearch(msg) + case msModeSelect: + return m.updateSelect(msg) + } + } + return m, nil +} + +func (m *multiSelectSearchField) updateSearch(msg tea.KeyPressMsg) (huh.Model, tea.Cmd) { + switch { + case key.Matches(msg, key.NewBinding(key.WithKeys("enter", "tab"))): + query := m.search.Value() + if query == m.lastQuery { + // Query unchanged — just switch to select mode. + m.mode = msModeSelect + m.search.Blur() + return m, nil + } + // New query — search in background with spinner. + return m, m.startSearch(query) + + case key.Matches(msg, key.NewBinding(key.WithKeys("shift+tab"))): + return m, huh.PrevField + + default: + var cmd tea.Cmd + m.search, cmd = m.search.Update(msg) + return m, cmd + } +} + +func (m *multiSelectSearchField) updateSelect(msg tea.KeyPressMsg) (huh.Model, tea.Cmd) { + switch { + case key.Matches(msg, key.NewBinding(key.WithKeys("shift+tab"))): + // Back to search mode. + m.mode = msModeSearch + m.search.Focus() + return m, nil + + case key.Matches(msg, key.NewBinding(key.WithKeys("enter"))): + return m, huh.NextField + + case key.Matches(msg, key.NewBinding(key.WithKeys("up", "k"))): + if m.cursor > 0 { + m.cursor-- + } + return m, nil + + case key.Matches(msg, key.NewBinding(key.WithKeys("down", "j"))): + if m.cursor < len(m.options)-1 { + m.cursor++ + } + return m, nil + + case key.Matches(msg, key.NewBinding(key.WithKeys("space", "x"))): + if len(m.options) > 0 { + k := m.options[m.cursor].value + m.selected[k] = !m.selected[k] + if !m.selected[k] { + delete(m.selected, k) + } + } + return m, nil + } + + return m, nil +} + +func (m *multiSelectSearchField) View() string { + styles := m.activeStyles() + var sb strings.Builder + + // Title. + if m.title != "" { + sb.WriteString(styles.Title.Render(m.title)) + sb.WriteString("\n") + } + + // Search input. + if m.searchTitle != "" { + sb.WriteString(styles.Description.Render(m.searchTitle)) + sb.WriteString("\n") + } + sb.WriteString(m.search.View()) + sb.WriteString("\n") + + // Options list. + if m.loading { + m.spinner.Style = styles.MultiSelectSelector.UnsetString() + sb.WriteString(m.spinner.View() + " Loading...") + sb.WriteString("\n") + } else if len(m.options) == 0 { + sb.WriteString(styles.UnselectedOption.Render(" No results")) + sb.WriteString("\n") + } else { + for i, o := range m.options { + cursor := m.mode == msModeSelect && i == m.cursor + isSelected := m.selected[o.value] + sb.WriteString(m.renderOption(o, cursor, isSelected)) + sb.WriteString("\n") + } + } + + return styles.Base.Width(m.width).Height(m.height).Render(sb.String()) +} + +func (m *multiSelectSearchField) renderOption(o msOption, cursor, selected bool) string { + styles := m.activeStyles() + + var parts []string + if cursor { + parts = append(parts, styles.MultiSelectSelector.String()) + } else { + parts = append(parts, strings.Repeat(" ", lipgloss.Width(styles.MultiSelectSelector.String()))) + } + if selected { + parts = append(parts, styles.SelectedPrefix.String()) + parts = append(parts, styles.SelectedOption.Render(o.label)) + } else { + parts = append(parts, styles.UnselectedPrefix.String()) + parts = append(parts, styles.UnselectedOption.Render(o.label)) + } + return lipgloss.JoinHorizontal(lipgloss.Left, parts...) +} + +func (m *multiSelectSearchField) activeStyles() *huh.FieldStyles { + theme := m.theme + if theme == nil { + theme = huh.ThemeFunc(huh.ThemeCharm) + } + if m.focused { + return &theme.Theme(m.hasDarkBg).Focused + } + return &theme.Theme(m.hasDarkBg).Blurred +} + +func (m *multiSelectSearchField) Focus() tea.Cmd { + m.focused = true + if m.mode == msModeSearch { + return m.search.Focus() + } + return nil +} + +func (m *multiSelectSearchField) Blur() tea.Cmd { + m.focused = false + m.search.Blur() + return nil +} + +func (m *multiSelectSearchField) Error() error { return m.err } +func (*multiSelectSearchField) Skip() bool { return false } +func (*multiSelectSearchField) Zoom() bool { return false } +func (m *multiSelectSearchField) GetKey() string { return m.key } +func (m *multiSelectSearchField) GetValue() any { return m.selectedKeys() } +func (m *multiSelectSearchField) Run() error { return huh.Run(m) } +func (m *multiSelectSearchField) RunAccessible(w io.Writer, r io.Reader) error { + _, _ = fmt.Fprintln(w, "MultiSelectWithSearch accessible mode not implemented") + return nil +} + +func (m *multiSelectSearchField) KeyBinds() []key.Binding { + if m.mode == msModeSearch { + return []key.Binding{ + key.NewBinding(key.WithKeys("enter"), key.WithHelp("enter", "search")), + key.NewBinding(key.WithKeys("shift+tab"), key.WithHelp("shift+tab", "back")), + } + } + return []key.Binding{ + key.NewBinding(key.WithKeys("x"), key.WithHelp("x", "toggle")), + key.NewBinding(key.WithKeys("up"), key.WithHelp("↑", "up")), + key.NewBinding(key.WithKeys("down"), key.WithHelp("↓", "down")), + key.NewBinding(key.WithKeys("shift+tab"), key.WithHelp("shift+tab", "search")), + key.NewBinding(key.WithKeys("enter"), key.WithHelp("enter", "confirm")), + } +} + +func (m *multiSelectSearchField) WithTheme(theme huh.Theme) huh.Field { + if m.theme != nil { + return m + } + m.theme = theme + + styles := theme.Theme(m.hasDarkBg) + st := m.search.Styles() + st.Cursor.Color = styles.Focused.TextInput.Cursor.GetForeground() + st.Focused.Prompt = styles.Focused.TextInput.Prompt + st.Focused.Text = styles.Focused.TextInput.Text + st.Focused.Placeholder = styles.Focused.TextInput.Placeholder + m.search.SetStyles(st) + + return m +} + +func (m *multiSelectSearchField) WithKeyMap(k *huh.KeyMap) huh.Field { + return m +} + +func (m *multiSelectSearchField) WithWidth(width int) huh.Field { + m.width = width + m.search.SetWidth(width) + return m +} + +func (m *multiSelectSearchField) WithHeight(height int) huh.Field { + m.height = height + return m +} + +func (m *multiSelectSearchField) WithPosition(p huh.FieldPosition) huh.Field { + m.position = p + return m +}