import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import type { GroupBase } from 'chakra-react-select';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { selectSystemShouldEnableModelDescriptions } from 'features/system/store/systemSlice';
import { groupBy, reduce } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';

/**
 * Arguments accepted by {@link useGroupedModelCombobox}.
 */
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
  /** Model configs to offer as selectable options. */
  modelConfigs: T[];
  /** Currently selected model; matched against options by its `key`. */
  selectedModel?: ModelIdentifierField | null;
  /** Called with the chosen model config, or `null` when the selection is cleared or cannot be resolved. */
  onChange: (value: T | null) => void;
  /** Optional predicate; when it returns `true` the model's option is marked disabled. */
  getIsDisabled?: (model: T) => boolean;
  /** When `true`, the placeholder shows the loading message. */
  isLoading?: boolean;
  /** When `true`, options are grouped by base and type instead of base only. Defaults to `false`. */
  groupByType?: boolean;
};

/**
 * Values returned by {@link useGroupedModelCombobox}.
 */
type UseGroupedModelComboboxReturn = {
  /** Option whose value equals the selected model's key, or `null` when there is none. */
  value: ComboboxOption | undefined | null;
  /** Options grouped under one label per group. */
  options: GroupBase<ComboboxOption>[];
  /** Change handler that maps the picked option back to its model config. */
  onChange: ComboboxOnChange;
  /** Translated placeholder text. */
  placeholder: string;
  /** Returns the translated "no matching models" message. */
  noOptionsMessage: () => string;
};

/** Group label made of the upper-cased model base. */
const groupByBaseFunc = <T extends AnyModelConfig>(model: T) => model.base.toUpperCase();

/** Group label made of the upper-cased base and the upper-cased type, with underscores in the type replaced by spaces. */
const groupByBaseAndTypeFunc = <T extends AnyModelConfig>(model: T) =>
  `${model.base.toUpperCase()} / ${model.type.replaceAll('_', ' ').toUpperCase()}`;

/** Selects the base of the model stored in the params slice, falling back to `'sdxl'` when it is nullish. */
const selectBaseWithSDXLFallback = createSelector(selectParamsSlice, (params) => params.model?.base ?? 'sdxl');

/**
 * Builds the props needed to render a grouped model combobox.
 *
 * Models are grouped by base (or by base and type when `groupByType` is set). Groups whose
 * base segment contains the base selected in the params slice are listed first; all other
 * groups keep the order produced by grouping.
 *
 * @param arg - Model configs, current selection, change callback and display flags.
 * @returns The grouped options, the selected option, a change handler, a placeholder and the
 * "no options" message getter.
 */
export const useGroupedModelCombobox = <T extends AnyModelConfig>(
  arg: UseGroupedModelComboboxArg<T>
): UseGroupedModelComboboxReturn => {
  const { t } = useTranslation();
  const base = useAppSelector(selectBaseWithSDXLFallback);
  const shouldShowModelDescriptions = useAppSelector(selectSystemShouldEnableModelDescriptions);
  const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading, groupByType = false } = arg;

  const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
    if (!modelConfigs) {
      return [];
    }

    const groupedModels = groupBy(modelConfigs, groupByType ? groupByBaseAndTypeFunc : groupByBaseFunc);

    const _options = reduce(
      groupedModels,
      (acc, val, label) => {
        acc.push({
          label,
          options: val.map((model) => ({
            label: model.name,
            value: model.key,
            // Only surface a description when the setting is enabled and the model has one.
            description: (shouldShowModelDescriptions && model.description) || undefined,
            isDisabled: getIsDisabled ? getIsDisabled(model) : false,
          })),
        });
        return acc;
      },
      [] as GroupBase<ComboboxOption>[]
    );

    // The text before the first '/' in a group label is the base segment.
    const isCurrentBase = (group: GroupBase<ComboboxOption>): boolean =>
      group.label?.split('/')[0]?.toLowerCase().includes(base) ?? false;

    // The comparator must look at both operands and return 0 for ties; a one-argument
    // comparator is inconsistent and leaves the resulting order implementation-defined.
    // Array.prototype.sort is stable, so groups within each partition keep their order.
    _options.sort((a, b) => Number(isCurrentBase(b)) - Number(isCurrentBase(a)));

    return _options;
  }, [modelConfigs, groupByType, getIsDisabled, base, shouldShowModelDescriptions]);

  const value = useMemo(() => {
    if (!selectedModel) {
      return null;
    }
    return options.flatMap((group) => group.options).find((option) => option.value === selectedModel.key) ?? null;
  }, [options, selectedModel]);

  const _onChange = useCallback<ComboboxOnChange>(
    (v) => {
      if (!v) {
        onChange(null);
        return;
      }
      // Resolve the picked option back to its config; report null when the key is unknown.
      const model = modelConfigs.find((m) => m.key === v.value);
      onChange(model ?? null);
    },
    [modelConfigs, onChange]
  );

  const placeholder = useMemo(() => {
    if (isLoading) {
      return t('common.loading');
    }

    if (options.length === 0) {
      return t('models.noModelsAvailable');
    }

    return t('models.selectModel');
  }, [isLoading, options, t]);

  const noOptionsMessage = useCallback(() => t('models.noMatchingModels'), [t]);

  return { options, value, onChange: _onChange, placeholder, noOptionsMessage };
};