import { create } from "zustand";
import { persist } from "zustand/middleware";
import { BUILTIN_MASKS } from "../masks";
import { getLang, Lang } from "../locales";
import { DEFAULT_TOPIC, Message } from "./chat";
import { ModelConfig, ModelType, useAppConfig } from "./config";
import { StoreKey } from "../constant";

/** Shape of a single mask entry held by the mask store. */
export type Mask = {
  id: number;
  avatar: string;
  name: string;
  context: Message[];
  modelConfig: ModelConfig;
  lang: Lang;
  builtin: boolean;
};

/** Initial (empty) persisted state of the mask store. */
export const DEFAULT_MASK_STATE = {
  /** User-created masks, keyed by mask id. */
  masks: {} as Record<number, Mask>,
  /** Monotonic counter; `create` increments it and uses the new value as the mask id. */
  globalMaskId: 0,
};

export type MaskState = typeof DEFAULT_MASK_STATE;

/** Persisted state plus the actions exposed by `useMaskStore`. */
type MaskStore = MaskState & {
  create: (mask?: Partial<Mask>) => Mask;
  update: (id: number, updater: (mask: Mask) => void) => void;
  delete: (id: number) => void;
  search: (text: string) => Mask[];
  get: (id?: number) => Mask | null;
  getAll: () => Mask[];
};

/** Placeholder id carried by `createEmptyMask()` and used as the fallback key in `get()`. */
export const DEFAULT_MASK_ID = 1145141919810;
export const DEFAULT_MASK_AVATAR = "gpt-bot";

/**
 * Builds a fresh, non-builtin mask with default values.
 *
 * The model config is a shallow copy of the current app config's `modelConfig`,
 * and the language is whatever `getLang()` returns at call time.
 */
export const createEmptyMask = () =>
  ({
    id: DEFAULT_MASK_ID,
    avatar: DEFAULT_MASK_AVATAR,
    name: DEFAULT_TOPIC,
    context: [],
    modelConfig: { ...useAppConfig.getState().modelConfig },
    lang: getLang(),
    builtin: false,
  } as Mask);

/**
 * Persisted store of user-defined masks.
 *
 * Every action that changes `masks` writes a NEW object instead of mutating the
 * current one, so subscribers that select `state.masks` observe a changed
 * reference and are notified of the update.
 */
export const useMaskStore = create<MaskStore>()(
  persist(
    (set, get) => ({
      ...DEFAULT_MASK_STATE,

      /**
       * Creates a user mask from defaults overlaid with `mask`.
       * `id` and `builtin` are always assigned by the store, overriding any
       * values supplied in `mask`.
       *
       * @returns the newly stored mask.
       */
      create(mask) {
        const id = get().globalMaskId + 1;
        const created: Mask = {
          ...createEmptyMask(),
          ...mask,
          id,
          builtin: false,
        };
        // Single set(): bump the counter and store the mask in one update.
        set((state) => ({
          globalMaskId: id,
          masks: { ...state.masks, [id]: created },
        }));
        return created;
      },

      /**
       * Applies `updater` to a shallow copy of the mask with the given id and
       * stores the result. Does nothing when no such mask exists.
       *
       * Note: the copy is shallow, so nested values such as `context` and
       * `modelConfig` are still shared with the previous mask object.
       */
      update(id, updater) {
        const current = get().masks[id];
        if (!current) return;
        const updated = { ...current };
        updater(updated);
        set((state) => ({ masks: { ...state.masks, [id]: updated } }));
      },

      /** Removes the mask with the given id from the store. */
      delete(id) {
        const masks = { ...get().masks };
        delete masks[id];
        set(() => ({ masks }));
      },

      /**
       * Looks up a user mask by id, falling back to `DEFAULT_MASK_ID` when no
       * id is given.
       *
       * @returns the mask, or `null` when nothing is stored under that id.
       */
      get(id) {
        return get().masks[id ?? DEFAULT_MASK_ID] ?? null;
      },

      /** Returns user masks sorted by descending id, followed by `BUILTIN_MASKS`. */
      getAll() {
        const userMasks = Object.values(get().masks).sort(
          (a, b) => b.id - a.id,
        );
        return userMasks.concat(BUILTIN_MASKS);
      },

      /**
       * Returns the user masks whose name contains `text` (case-insensitive).
       * An empty or whitespace-only query returns every user mask.
       * Built-in masks are not included.
       */
      search(text) {
        const userMasks = Object.values(get().masks);
        const query = text.trim().toLowerCase();
        if (query.length === 0) return userMasks;
        return userMasks.filter((mask) =>
          mask.name.toLowerCase().includes(query),
        );
      },
    }),
    {
      // Storage key and schema version passed to the persist middleware.
      name: StoreKey.Mask,
      version: 2,
    },
  ),
);