123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521 |
- import { create } from "zustand";
- import { persist } from "zustand/middleware";
- import { type ChatCompletionResponseMessage } from "openai";
- import {
- ControllerPool,
- requestChatStream,
- requestWithPrompt,
- } from "../requests";
- import { trimTopic } from "../utils";
- import Locale from "../locales";
- import { showToast } from "../components/ui-lib";
- import { ModelType } from "./config";
- import { createEmptyMask, Mask } from "./mask";
- import { StoreKey } from "../constant";
/**
 * A chat message as kept in a session: the OpenAI response message
 * (`role` + `content`) plus client-side bookkeeping fields.
 */
export type Message = ChatCompletionResponseMessage & {
  /** Identifier; `createMessage` defaults it to `Date.now()`. */
  id?: number;
  /** Display timestamp; `createMessage` fills it with `toLocaleString()`. */
  date: string;
  /** Model that produced the message (set on assistant replies by `onUserInput`). */
  model?: ModelType;
  /** True while an assistant reply is still being streamed in. */
  streaming?: boolean;
  /** Marks a failed exchange; such messages are left out of request history. */
  isError?: boolean;
};
/**
 * Builds a `Message`, filling defaults for anything `override` omits:
 * a `Date.now()` id, the current locale-formatted date, the `"user"` role
 * and empty content. Any field present in `override` wins.
 */
export function createMessage(override: Partial<Message>): Message {
  const base: Message = {
    id: Date.now(),
    date: new Date().toLocaleString(),
    role: "user",
    content: "",
  };
  return { ...base, ...override };
}
/** All message roles: system, user and assistant. */
export const ROLES: Array<Message["role"]> = ["system", "user", "assistant"];
/** Running size statistics kept per chat session. */
export interface ChatStat {
  /** Characters of message content passed to `updateStat`. */
  charCount: number;
  /** Word count; not maintained by this store yet. */
  wordCount: number;
  /** Token count; not maintained by this store yet. */
  tokenCount: number;
}
/** One conversation together with its memory, statistics and mask. */
export interface ChatSession {
  /** Identifier; `newSession` assigns the incremented `globalId`. */
  id: number;
  /** Display title; starts as `DEFAULT_TOPIC` until a topic is summarized. */
  topic: string;
  /** Mask the session runs with (its `context` messages and `modelConfig`). */
  mask: Mask;
  /** Full message history, including messages flagged with `isError`. */
  messages: Message[];
  /** Summary of older messages, sent as long-term memory when enabled. */
  memoryPrompt: string;
  /** Index into `messages` up to which history was folded into `memoryPrompt`. */
  lastSummarizeIndex: number;
  stat: ChatStat;
  /** `Date.now()` timestamp of the last recorded activity. */
  lastUpdate: number;
}
/** Placeholder title a session keeps until a topic is summarized for it. */
export const DEFAULT_TOPIC = Locale.Store.DefaultTopic;

/** Localized greeting, pre-built as an assistant message. */
export const BOT_HELLO: Message = createMessage({
  content: Locale.Store.BotHello,
  role: "assistant",
});
/**
 * Creates a blank session: default topic, no messages or memory, zeroed
 * statistics and a fresh empty mask. The id adds `Math.random()` to the clock
 * so sessions created in the same millisecond are very unlikely to collide.
 */
function createEmptySession(): ChatSession {
  const stat: ChatStat = { tokenCount: 0, wordCount: 0, charCount: 0 };
  const session: ChatSession = {
    id: Date.now() + Math.random(),
    topic: DEFAULT_TOPIC,
    memoryPrompt: "",
    messages: [],
    stat,
    lastUpdate: Date.now(),
    lastSummarizeIndex: 0,
    mask: createEmptyMask(),
  };
  return session;
}
/** Shape of the chat store: persisted state plus the actions operating on it. */
interface ChatStore {
  // ---- state ----
  sessions: ChatSession[];
  /** Index into `sessions` of the session currently open. */
  currentSessionIndex: number;
  /** Counter `newSession` increments to assign session ids. */
  globalId: number;

  // ---- session management ----
  clearSessions: () => void;
  newSession: (mask?: Mask) => void;
  selectSession: (index: number) => void;
  moveSession: (from: number, to: number) => void;
  deleteSession: (index: number) => void;
  currentSession: () => ChatSession;
  resetSession: () => void;

  // ---- messages ----
  onUserInput: (content: string) => Promise<void>;
  onNewMessage: (message: Message) => void;
  updateMessage: (
    sessionIndex: number,
    messageIndex: number,
    updater: (message?: Message) => void,
  ) => void;
  updateCurrentSession: (updater: (session: ChatSession) => void) => void;
  updateStat: (message: Message) => void;

  // ---- memory ----
  summarizeSession: () => void;
  getMessagesWithMemory: () => Message[];
  getMemoryPrompt: () => Message;

  // ---- maintenance ----
  clearAllData: () => void;
}
/** Total number of characters across the `content` of every message in `msgs`. */
function countMessages(msgs: Message[]) {
  let total = 0;
  for (const { content } of msgs) {
    total += content.length;
  }
  return total;
}
- export const useChatStore = create<ChatStore>()(
- persist(
- (set, get) => ({
- sessions: [createEmptySession()],
- currentSessionIndex: 0,
- globalId: 0,
/** Discards all sessions and starts over with a single empty, selected one. */
clearSessions() {
  const sessions = [createEmptySession()];
  set({ currentSessionIndex: 0, sessions });
},
/**
 * Makes the session at `index` current. No bounds check happens here;
 * `currentSession` clamps an out-of-range index when it is next read.
 */
selectSession(index: number) {
  const currentSessionIndex = index;
  set(() => ({ currentSessionIndex }));
},
/**
 * Moves the session at `from` to position `to` and adjusts
 * `currentSessionIndex` so the same session stays selected.
 */
moveSession(from: number, to: number) {
  set(({ sessions, currentSessionIndex: selected }) => {
    // reorder a copy so the previous array is left untouched
    const reordered = sessions.slice();
    const moved = reordered[from];
    reordered.splice(from, 1);
    reordered.splice(to, 0, moved);

    // keep pointing at the session that was selected before the move
    let nextSelected: number;
    if (selected === from) {
      nextSelected = to;
    } else if (selected > from && selected <= to) {
      // a session moved from before it to after it: its index drops by one
      nextSelected = selected - 1;
    } else if (selected < from && selected >= to) {
      // a session moved from after it to before it: its index grows by one
      nextSelected = selected + 1;
    } else {
      nextSelected = selected;
    }

    return { currentSessionIndex: nextSelected, sessions: reordered };
  });
},
/**
 * Creates a session, assigns it the next `globalId` as its id and inserts it
 * at the front of the list as the current session. When a `mask` is given,
 * the session gets a shallow copy of it and uses the mask's name as topic.
 */
newSession(mask) {
  const session = createEmptySession();

  const nextGlobalId = get().globalId + 1;
  set({ globalId: nextGlobalId });
  session.id = get().globalId;

  if (mask) {
    // shallow copy: objects nested inside the mask remain shared with it
    session.mask = { ...mask };
    session.topic = mask.name;
  }

  set(({ sessions }) => ({
    currentSessionIndex: 0,
    sessions: [session, ...sessions],
  }));
},
/**
 * Removes the session at `index` (a no-op when there is none there) and
 * picks a new current index. Deleting the only session replaces it with an
 * empty one. Afterwards a toast (showToast is given 5000, presumably ms) offers
 * a revert action that restores the previous sessions and selection.
 */
deleteSession(index) {
  const { sessions: before, currentSessionIndex: selected } = get();
  if (!before.at(index)) return;

  // snapshot for the undo action
  const restoreState = {
    currentSessionIndex: selected,
    sessions: before.slice(),
  };

  const remaining = before.slice();
  remaining.splice(index, 1);

  let nextIndex: number;
  if (before.length === 1) {
    // never leave the store without a session
    remaining.push(createEmptySession());
    nextIndex = 0;
  } else {
    // removing a session before the selected one shifts the selection left
    const shift = index < selected ? 1 : 0;
    nextIndex = Math.min(selected - shift, remaining.length - 1);
  }

  set(() => ({
    currentSessionIndex: nextIndex,
    sessions: remaining,
  }));

  showToast(
    Locale.Home.DeleteToast,
    {
      text: Locale.Home.Revert,
      onClick() {
        set(() => restoreState);
      },
    },
    5000,
  );
},
/**
 * Returns the selected session. If `currentSessionIndex` is out of range it
 * is clamped into `[0, sessions.length - 1]`, and the clamped value is
 * written back to the store before returning.
 */
currentSession() {
  const { sessions, currentSessionIndex } = get();
  const outOfRange =
    currentSessionIndex < 0 || currentSessionIndex >= sessions.length;
  if (!outOfRange) {
    return sessions[currentSessionIndex];
  }

  const clamped = Math.min(
    sessions.length - 1,
    Math.max(0, currentSessionIndex),
  );
  set(() => ({ currentSessionIndex: clamped }));
  return sessions[clamped];
},
/**
 * Post-processing for a finished message: stamps the current session's
 * `lastUpdate`, adds the message to its statistics, then runs
 * `summarizeSession`.
 */
onNewMessage(message) {
  const { updateCurrentSession, updateStat, summarizeSession } = get();
  updateCurrentSession((session) => {
    session.lastUpdate = Date.now();
  });
  updateStat(message);
  summarizeSession();
},
/**
 * Sends `content` as a user message in the current session and streams the
 * assistant's reply into a placeholder message.
 *
 * The request payload is: a system message (model name + current time), the
 * history from `getMessagesWithMemory`, then the new user message. Both the
 * user message and the streaming bot placeholder are appended to the session
 * before the request starts; the bot message is then mutated in place as
 * chunks arrive. On error both messages get `isError` unless the request was
 * aborted.
 *
 * `requestChatStream` is not awaited: completion is reported only through
 * the callbacks below.
 */
async onUserInput(content) {
  const session = get().currentSession();
  const modelConfig = session.mask.modelConfig;

  const userMessage: Message = createMessage({
    role: "user",
    content,
  });

  // ids are chained off the user message so the three messages stay distinct
  const botMessage: Message = createMessage({
    role: "assistant",
    streaming: true,
    id: userMessage.id! + 1,
    model: modelConfig.model,
  });

  const systemInfo = createMessage({
    role: "system",
    content: `IMPORTANT: You are a virtual assistant powered by the ${modelConfig.model} model, now time is ${new Date().toLocaleString()}`,
    id: botMessage.id! + 1,
  });

  // get recent messages
  const systemMessages = [systemInfo];
  const recentMessages = get().getMessagesWithMemory();
  const sendMessages = systemMessages.concat(
    recentMessages.concat(userMessage),
  );

  const sessionIndex = get().currentSessionIndex;
  // position the bot message will occupy once both messages are pushed;
  // only used as a ControllerPool key when botMessage.id is unset
  const messageIndex = get().currentSession().messages.length + 1;

  // save user's and bot's message
  get().updateCurrentSession((current) => {
    current.messages.push(userMessage);
    current.messages.push(botMessage);
  });

  // make request
  console.log("[User Input] ", sendMessages);
  requestChatStream(sendMessages, {
    onMessage(content, done) {
      // stream response
      if (done) {
        botMessage.streaming = false;
        botMessage.content = content;
        get().onNewMessage(botMessage);
        ControllerPool.remove(sessionIndex, botMessage.id ?? messageIndex);
      } else {
        botMessage.content = content;
        // botMessage was mutated in place; an empty set() publishes the change
        set(() => ({}));
      }
    },
    onError(error, statusCode) {
      const isAborted = error.message.includes("aborted");
      if (statusCode === 401) {
        botMessage.content = Locale.Error.Unauthorized;
      } else if (!isAborted) {
        botMessage.content += "\n\n" + Locale.Store.Error;
      }
      botMessage.streaming = false;
      userMessage.isError = !isAborted;
      botMessage.isError = !isAborted;
      set(() => ({}));
      ControllerPool.remove(sessionIndex, botMessage.id ?? messageIndex);
    },
    onController(controller) {
      // collect controller for stop/retry
      ControllerPool.addController(
        sessionIndex,
        botMessage.id ?? messageIndex,
        controller,
      );
    },
    modelConfig: { ...modelConfig },
  });
},
/**
 * Wraps the current session's long-term memory in a system message
 * (empty `date`). When there is no memory yet, the content is `""`.
 */
getMemoryPrompt() {
  const { memoryPrompt } = get().currentSession();
  const content =
    memoryPrompt.length === 0 ? "" : Locale.Store.Prompt.History(memoryPrompt);
  const message: Message = { role: "system", content, date: "" };
  return message;
},
/**
 * Builds the history sent along with a new request, in order:
 *   1. the mask's context messages,
 *   2. the long-term memory prompt, when `sendMemory` is on and a memory exists,
 *   3. the most recent non-error messages, limited by
 *      - `historyMessageCount` (only the last N non-error messages qualify),
 *      - `lastSummarizeIndex` (messages already folded into memory are skipped),
 *      - `compressMessageLengthThreshold` (collection walks back from the newest
 *        message and stops once this many characters are gathered; the message
 *        that crosses the threshold is still included).
 */
getMessagesWithMemory() {
  const session = get().currentSession();
  const modelConfig = session.mask.modelConfig;
  const context = session.mask.context.slice();

  // long term memory
  if (
    modelConfig.sendMemory &&
    session.memoryPrompt &&
    session.memoryPrompt.length > 0
  ) {
    context.push(get().getMemoryPrompt());
  }

  const messages = session.messages.filter((msg) => !msg.isError);
  const n = messages.length;

  // `lastSummarizeIndex` indexes the unfiltered `session.messages`; convert it
  // to an index into the filtered `messages`, otherwise errored messages before
  // it would push the window past messages that were never summarized.
  const longTermMemoryMessageIndex = session.messages
    .slice(0, session.lastSummarizeIndex)
    .filter((msg) => !msg.isError).length;
  const shortTermMemoryMessageIndex = Math.max(
    0,
    n - modelConfig.historyMessageCount,
  );
  const oldestIndex = Math.max(
    shortTermMemoryMessageIndex,
    longTermMemoryMessageIndex,
  );

  // get recent messages as many as possible
  const threshold = modelConfig.compressMessageLengthThreshold;
  const reversedRecentMessages: Message[] = [];
  for (
    let i = n - 1, count = 0;
    i >= oldestIndex && count < threshold;
    i -= 1
  ) {
    const msg = messages[i];
    if (!msg) continue;
    count += msg.content.length;
    reversedRecentMessages.push(msg);
  }

  return context.concat(reversedRecentMessages.reverse());
},
/**
 * Runs `updater` on message `messageIndex` of session `sessionIndex` (both
 * looked up with `.at`, so negative indices count from the end; `undefined`
 * is passed when either lookup misses), then commits the sessions array.
 */
updateMessage(
  sessionIndex: number,
  messageIndex: number,
  updater: (message?: Message) => void,
) {
  const { sessions } = get();
  const target = sessions.at(sessionIndex)?.messages?.at(messageIndex);
  updater(target);
  set(() => ({ sessions }));
},
/**
 * Clears the current session's messages and long-term memory.
 *
 * `lastSummarizeIndex` is reset too. It indexes into `messages`, and a stale
 * value would make `getMessagesWithMemory` skip the first messages of the new
 * conversation and `summarizeSession` ignore them.
 */
resetSession() {
  get().updateCurrentSession((session) => {
    session.messages = [];
    session.memoryPrompt = "";
    session.lastSummarizeIndex = 0;
  });
},
/**
 * Background maintenance for the current session, run after each reply.
 *
 * 1. Topic: while the session still has `DEFAULT_TOPIC` and its messages total
 *    at least 50 characters, asks gpt-3.5-turbo for a title.
 * 2. Memory: when `sendMemory` is on and the not-yet-summarized messages exceed
 *    `compressMessageLengthThreshold` characters, streams a summary (seeded with
 *    the existing memory prompt) into `memoryPrompt`. When the stream finishes,
 *    it advances `lastSummarizeIndex` past the summarized messages.
 *
 * Async results are written to the session captured when this call started,
 * so switching sessions while a request is in flight does not retarget them.
 */
summarizeSession() {
  const session = get().currentSession();

  // summarize the topic once the messages total at least 50 characters
  const SUMMARIZE_MIN_LEN = 50;
  if (
    session.topic === DEFAULT_TOPIC &&
    countMessages(session.messages) >= SUMMARIZE_MIN_LEN
  ) {
    requestWithPrompt(session.messages, Locale.Store.Prompt.Topic, {
      model: "gpt-3.5-turbo",
    })
      .then((res) => {
        session.topic = res ? trimTopic(res) : DEFAULT_TOPIC;
        // the session was mutated directly; commit through set()
        set(() => ({}));
      })
      .catch((error) => {
        console.error("[Topic] ", error);
      });
  }

  const modelConfig = session.mask.modelConfig;
  let toBeSummarizedMsgs = session.messages.slice(session.lastSummarizeIndex);
  const historyMsgLength = countMessages(toBeSummarizedMsgs);

  // `>` binds tighter than `??`; the parentheses make 4000 the fallback limit.
  // NOTE(review): compares a character count with `max_tokens` - presumably a
  // rough cap, confirm intent.
  if (historyMsgLength > (modelConfig?.max_tokens ?? 4000)) {
    const n = toBeSummarizedMsgs.length;
    toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
      Math.max(0, n - modelConfig.historyMessageCount),
    );
  }

  // add memory prompt
  toBeSummarizedMsgs.unshift(get().getMemoryPrompt());

  const lastSummarizeIndex = session.messages.length;

  console.log(
    "[Chat History] ",
    toBeSummarizedMsgs,
    historyMsgLength,
    modelConfig.compressMessageLengthThreshold,
  );

  if (
    historyMsgLength > modelConfig.compressMessageLengthThreshold &&
    modelConfig.sendMemory
  ) {
    requestChatStream(
      toBeSummarizedMsgs.concat({
        role: "system",
        content: Locale.Store.Prompt.Summarize,
        date: "",
      }),
      {
        overrideModel: "gpt-3.5-turbo",
        onMessage(message, done) {
          session.memoryPrompt = message;
          if (done) {
            console.log("[Memory] ", session.memoryPrompt);
            session.lastSummarizeIndex = lastSummarizeIndex;
            // fields above were mutated directly; commit through set() so
            // subscribers and the persist middleware see the new memory
            set(() => ({}));
          }
        },
        onError(error) {
          console.error("[Summarize] ", error);
        },
      },
    );
  }
},
/** Adds the length of `message.content` to the current session's character count. */
updateStat(message) {
  const added = message.content.length;
  get().updateCurrentSession(({ stat }) => {
    stat.charCount += added;
    // TODO: should update chat count and word count
  });
},
/**
 * Lets `updater` mutate the current session in place, then commits the
 * (same) sessions array through `set`.
 */
updateCurrentSession(updater) {
  const { sessions, currentSessionIndex } = get();
  updater(sessions[currentSessionIndex]);
  set(() => ({ sessions }));
},
/** Clears all of this origin's localStorage (not only this store) and reloads the page. */
clearAllData() {
  const storage = localStorage;
  storage.clear();
  location.reload();
},
- }),
- {
- name: StoreKey.Chat,
- version: 2,
/**
 * Upgrades chat state persisted by older versions of this store.
 *
 * For state older than version 2, every session is rebuilt on top of
 * `createEmptySession()` (which supplies a fresh mask). Only its `topic` and
 * `messages` are carried over, memory is enabled (`historyMessageCount` 4,
 * compression threshold 1000), and `globalId` restarts at 0. Missing or
 * malformed `sessions`/`messages` are treated as empty. At least one session
 * is always kept, matching `deleteSession`, which never leaves the store empty.
 */
migrate(persistedState, version) {
  const state = (persistedState ?? {}) as any;
  const newState = JSON.parse(JSON.stringify(state)) as ChatStore;

  if (version < 2) {
    newState.globalId = 0;
    newState.sessions = [];

    const oldSessions: any[] = Array.isArray(state.sessions)
      ? state.sessions
      : [];
    for (const oldSession of oldSessions) {
      const newSession = createEmptySession();
      newSession.topic = oldSession?.topic ?? newSession.topic;
      newSession.messages = Array.isArray(oldSession?.messages)
        ? [...oldSession.messages]
        : [];
      newSession.mask.modelConfig.sendMemory = true;
      newSession.mask.modelConfig.historyMessageCount = 4;
      newSession.mask.modelConfig.compressMessageLengthThreshold = 1000;
      newState.sessions.push(newSession);
    }

    if (newState.sessions.length === 0) {
      newState.sessions.push(createEmptySession());
    }
  }

  return newState;
},
- },
- ),
- );
|