mask.ts 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import { create } from "zustand";
  2. import { persist } from "zustand/middleware";
  3. import { BUILTIN_MASKS } from "../masks";
  4. import { getLang, Lang } from "../locales";
  5. import { DEFAULT_TOPIC, Message } from "./chat";
  6. import { ModelConfig, ModelType, useAppConfig } from "./config";
  7. import { StoreKey } from "../constant";
  /**
   * A mask record, as stored by `useMaskStore` (user masks) or listed in
   * `BUILTIN_MASKS` (both are returned together by the store's `getAll()`).
   */
  export type Mask = {
    /** Numeric key; masks made by the store's `create()` get sequential ids. */
    id: number;
    /**
     * `false` for masks made by the store's `create()`; presumably `true` for
     * `BUILTIN_MASKS` entries — NOTE(review): confirm in ../masks.
     */
    builtin: boolean;

    /** Display name; empty masks start with `DEFAULT_TOPIC`. */
    name: string;
    /** Avatar identifier; empty masks start with `DEFAULT_MASK_AVATAR`. */
    avatar: string;
    /** Language of the mask; empty masks take the value of `getLang()`. */
    lang: Lang;

    /** Messages stored as this mask's context. */
    context: Message[];
    /** Model settings; empty masks get a copy of the app-level model config. */
    modelConfig: ModelConfig;
  };
  /**
   * Initial data of the mask store: no user masks yet and an id counter at 0,
   * so the first mask created by the store receives id 1.
   */
  export const DEFAULT_MASK_STATE = {
    // User masks keyed by their numeric id.
    masks: {} as { [id: number]: Mask },
    // Last id handed out by the store's create().
    globalMaskId: 0,
  };

  /** The data portion of the mask store, derived from its default value. */
  export type MaskState = typeof DEFAULT_MASK_STATE;
  /** The mask store: persisted `MaskState` plus the actions operating on it. */
  interface MaskStore extends MaskState {
    /** Add a new user mask built from optional overrides and return it. */
    create: (mask?: Partial<Mask>) => Mask;
    /** Change the mask with `id` through a callback that edits a copy of it. */
    update: (id: number, updater: (mask: Mask) => void) => void;
    /** Remove the mask with `id`. */
    delete: (id: number) => void;
    /** Mask search by free text; the matching rule lives in the store. */
    search: (text: string) => Mask[];
    /** Look up one mask; `id` falls back to the default-mask id when omitted. */
    get: (id?: number) => Mask | null;
    /** All masks: user masks followed by the built-in ones. */
    getAll: () => Mask[];
  }
  /** Id given to masks built by `createEmptyMask()`; also the lookup default of the store's `get()`. */
  export const DEFAULT_MASK_ID = 1_145_141_919_810;
  /** Avatar identifier assigned by `createEmptyMask()`. */
  export const DEFAULT_MASK_AVATAR = "gpt-bot";
  /**
   * Build a fresh, unsaved mask populated with defaults.
   *
   * The defaults are:
   * - `DEFAULT_MASK_ID` as its id and `DEFAULT_MASK_AVATAR` as its avatar;
   * - `DEFAULT_TOPIC` as its name;
   * - an empty context;
   * - a copy of the model config currently held by `useAppConfig`;
   * - the language returned by `getLang()`.
   *
   * The return type is declared instead of asserting `as Mask`. The assertion
   * let a missing or mistyped field compile silently; the annotation makes the
   * compiler check every field of the literal against `Mask`.
   */
  export const createEmptyMask = (): Mask => ({
    id: DEFAULT_MASK_ID,
    avatar: DEFAULT_MASK_AVATAR,
    name: DEFAULT_TOPIC,
    context: [],
    // Shallow copy: reassigning top-level fields of this mask's config does
    // not write through to the app-level config object. Nested values are
    // still shared.
    modelConfig: { ...useAppConfig.getState().modelConfig },
    lang: getLang(),
    builtin: false,
  });
  /**
   * Persisted zustand store for user-created masks.
   *
   * User masks live in `masks`, keyed by a numeric id taken from the
   * `globalMaskId` counter (each `create()` uses counter + 1). Entries of
   * `BUILTIN_MASKS` are not stored here; `getAll()` appends them on read.
   *
   * Every action writes a NEW `masks` object rather than mutating the current
   * one. Mutating in place would have two effects:
   * - the `masks` reference stays the same, so selectors such as
   *   `useMaskStore((s) => s.masks)` would see no change;
   * - the write would land in `DEFAULT_MASK_STATE.masks`, which the initial
   *   state shares by reference through the spread below.
   */
  export const useMaskStore = create<MaskStore>()(
    persist(
      (set, get) => ({
        // Shallow spread: the initial `masks` is the very object that
        // DEFAULT_MASK_STATE holds, which is why no action may mutate it.
        ...DEFAULT_MASK_STATE,

        /**
         * Add a user mask built from `createEmptyMask()` overlaid with `mask`,
         * and return it. `id` and `builtin` in `mask` are ignored: the id is
         * freshly allocated and `builtin` is forced to `false`.
         */
        create(mask) {
          const id = get().globalMaskId + 1;
          const newMask: Mask = {
            ...createEmptyMask(),
            ...mask,
            id,
            builtin: false,
          };
          // Advance the counter and insert the mask in a single update so
          // subscribers never observe one without the other.
          set((state) => ({
            globalMaskId: id,
            masks: { ...state.masks, [id]: newMask },
          }));
          return newMask;
        },

        /**
         * Apply `updater` to the mask with `id`; does nothing if no such mask
         * exists. The updater receives a shallow copy. Reassigning its
         * top-level fields is therefore safe. Nested values (`context`,
         * `modelConfig`) are still shared with the stored mask and should be
         * replaced rather than mutated in place.
         */
        update(id, updater) {
          const current = get().masks[id];
          if (!current) return;
          const draft = { ...current };
          updater(draft);
          set((state) => ({ masks: { ...state.masks, [id]: draft } }));
        },

        /** Remove the mask with `id`; the contents are unchanged if it is absent. */
        delete(id) {
          const masks = { ...get().masks };
          delete masks[id];
          set(() => ({ masks }));
        },

        /**
         * Return the stored mask with `id`, or `null` when there is none. This
         * matches the declared `Mask | null`; previously the missing case
         * returned `undefined`. When `id` is omitted, `DEFAULT_MASK_ID` is
         * looked up, the same value that was hard-coded here before.
         */
        get(id) {
          return get().masks[id ?? DEFAULT_MASK_ID] ?? null;
        },

        /** User masks with the highest id first, followed by the built-in masks. */
        getAll() {
          const userMasks = Object.values(get().masks).sort(
            (a, b) => b.id - a.id,
          );
          return userMasks.concat(BUILTIN_MASKS);
        },

        /**
         * Return user masks whose name contains `text`.
         *
         * Matching is case-insensitive and ignores surrounding whitespace.
         * An empty or blank query returns every user mask, which was the old
         * result for every query since `text` used to be ignored. Built-in
         * masks are not searched.
         */
        search(text) {
          const query = text.trim().toLowerCase();
          const userMasks = Object.values(get().masks);
          if (!query) return userMasks;
          return userMasks.filter((mask) =>
            mask.name.toLowerCase().includes(query),
          );
        },
      }),
      {
        name: StoreKey.Mask,
        version: 2,
      },
    ),
  );