UNPKG

@assistant-ui/react

Version:

TypeScript/React library for AI Chat

500 lines 19.5 kB
"use client"; import { jsx as _jsx } from "react/jsx-runtime"; import { generateId } from "../../../internal.js"; import { RemoteThreadListHookInstanceManager } from "./RemoteThreadListHookInstanceManager.js"; import { BaseSubscribable } from "./BaseSubscribable.js"; import { EMPTY_THREAD_CORE } from "./EMPTY_THREAD_CORE.js"; import { OptimisticState } from "./OptimisticState.js"; import { Fragment, useEffect, useId } from "react"; import { create } from "zustand"; import { AssistantMessageStream } from "assistant-stream"; import { RuntimeAdapterProvider } from "../adapters/RuntimeAdapterProvider.js"; function createThreadMappingId(id) { return id; } const getThreadData = (state, threadIdOrRemoteId) => { const idx = state.threadIdMap[threadIdOrRemoteId]; if (idx === undefined) return undefined; return state.threadData[idx]; }; const updateStatusReducer = (state, threadIdOrRemoteId, newStatus) => { const data = getThreadData(state, threadIdOrRemoteId); if (!data) return state; const { id, remoteId, status: lastStatus } = data; if (lastStatus === newStatus) return state; const newState = { ...state }; // lastStatus switch (lastStatus) { case "new": newState.newThreadId = undefined; break; case "regular": newState.threadIds = newState.threadIds.filter((t) => t !== id); break; case "archived": newState.archivedThreadIds = newState.archivedThreadIds.filter((t) => t !== id); break; default: { const _exhaustiveCheck = lastStatus; throw new Error(`Unsupported state: ${_exhaustiveCheck}`); } } // newStatus switch (newStatus) { case "regular": newState.threadIds = [id, ...newState.threadIds]; break; case "archived": newState.archivedThreadIds = [id, ...newState.archivedThreadIds]; break; case "deleted": newState.threadData = Object.fromEntries(Object.entries(newState.threadData).filter(([key]) => key !== id)); newState.threadIdMap = Object.fromEntries(Object.entries(newState.threadIdMap).filter(([key]) => key !== id && key !== remoteId)); break; default: { const _exhaustiveCheck = newStatus; throw new Error(`Unsupported state: ${_exhaustiveCheck}`); } } if (newStatus !== "deleted") { newState.threadData = { ...newState.threadData, [id]: { ...data, status: newStatus, }, }; } return newState; }; export class RemoteThreadListThreadListRuntimeCore extends BaseSubscribable { contextProvider; _options; _hookManager; _loadThreadsPromise; _mainThreadId; _state = new OptimisticState({ isLoading: false, newThreadId: undefined, threadIds: [], archivedThreadIds: [], threadIdMap: {}, threadData: {}, }); get threadData() { return this._state.value.threadData; } getLoadThreadsPromise() { // TODO this needs to be cached in case this promise is loaded during suspense if (!this._loadThreadsPromise) { this._loadThreadsPromise = this._state .optimisticUpdate({ execute: () => this._options.adapter.list(), loading: (state) => { return { ...state, isLoading: true, }; }, then: (state, l) => { const newThreadIds = []; const newArchivedThreadIds = []; const newThreadIdMap = {}; const newThreadData = {}; for (const thread of l.threads) { switch (thread.status) { case "regular": newThreadIds.push(thread.remoteId); break; case "archived": newArchivedThreadIds.push(thread.remoteId); break; default: { const _exhaustiveCheck = thread.status; throw new Error(`Unsupported state: ${_exhaustiveCheck}`); } } const mappingId = createThreadMappingId(thread.remoteId); newThreadIdMap[thread.remoteId] = mappingId; newThreadData[mappingId] = { id: thread.remoteId, remoteId: thread.remoteId, externalId: thread.externalId, status: thread.status, title: thread.title, initializeTask: Promise.resolve({ remoteId: thread.remoteId, externalId: thread.externalId, }), }; } return { ...state, threadIds: newThreadIds, archivedThreadIds: newArchivedThreadIds, threadIdMap: { ...state.threadIdMap, ...newThreadIdMap, }, threadData: { ...state.threadData, ...newThreadData, }, }; }, }) .then(() => { }); } return this._loadThreadsPromise; } constructor(options, contextProvider) { super(); this.contextProvider = contextProvider; this._state.subscribe(() => this._notifySubscribers()); this._hookManager = new RemoteThreadListHookInstanceManager(options.runtimeHook); this.useProvider = create(() => ({ Provider: options.adapter.unstable_Provider ?? Fragment, })); this.__internal_setOptions(options); this.switchToNewThread(); } useProvider; __internal_setOptions(options) { if (this._options === options) return; this._options = options; const Provider = options.adapter.unstable_Provider ?? Fragment; if (Provider !== this.useProvider.getState().Provider) { this.useProvider.setState({ Provider }, true); } this._hookManager.setRuntimeHook(options.runtimeHook); } __internal_load() { this.getLoadThreadsPromise(); // begin loading on initial bind } get isLoading() { return this._state.value.isLoading; } get threadIds() { return this._state.value.threadIds; } get archivedThreadIds() { return this._state.value.archivedThreadIds; } get newThreadId() { return this._state.value.newThreadId; } get mainThreadId() { return this._mainThreadId; } getMainThreadRuntimeCore() { const result = this._hookManager.getThreadRuntimeCore(this._mainThreadId); if (!result) return EMPTY_THREAD_CORE; return result; } getThreadRuntimeCore(threadIdOrRemoteId) { const data = this.getItemById(threadIdOrRemoteId); if (!data) throw new Error("Thread not found"); const result = this._hookManager.getThreadRuntimeCore(data.id); if (!result) throw new Error("Thread not found"); return result; } getItemById(threadIdOrRemoteId) { return getThreadData(this._state.value, threadIdOrRemoteId); } async switchToThread(threadIdOrRemoteId) { let data = this.getItemById(threadIdOrRemoteId); if (!data) { const remoteMetadata = await this._options.adapter.fetch(threadIdOrRemoteId); const state = this._state.value; const mappingId = createThreadMappingId(remoteMetadata.remoteId); const newThreadData = { ...state.threadData, [mappingId]: { id: mappingId, initializeTask: Promise.resolve({ remoteId: remoteMetadata.remoteId, externalId: remoteMetadata.externalId, }), remoteId: remoteMetadata.remoteId, externalId: remoteMetadata.externalId, status: remoteMetadata.status, title: remoteMetadata.title, }, }; const newThreadIdMap = { ...state.threadIdMap, [remoteMetadata.remoteId]: mappingId, }; const newThreadIds = remoteMetadata.status === "regular" ? [...state.threadIds, remoteMetadata.remoteId] : state.threadIds; const newArchivedThreadIds = remoteMetadata.status === "archived" ? [...state.archivedThreadIds, remoteMetadata.remoteId] : state.archivedThreadIds; this._state.update({ ...state, threadIds: newThreadIds, archivedThreadIds: newArchivedThreadIds, threadIdMap: newThreadIdMap, threadData: newThreadData, }); data = this.getItemById(threadIdOrRemoteId); } if (!data) throw new Error("Thread not found"); if (this._mainThreadId === data.id) return; const task = this._hookManager.startThreadRuntime(data.id); if (this.mainThreadId !== undefined) { await task; } else { task.then(() => this._notifySubscribers()); } if (data.status === "archived") await this.unarchive(data.id); this._mainThreadId = data.id; this._notifySubscribers(); } async switchToNewThread() { // an initialization transaction is in progress, wait for it to settle while (this._state.baseValue.newThreadId !== undefined && this._state.value.newThreadId === undefined) { await this._state.waitForUpdate(); } const state = this._state.value; let id = this._state.value.newThreadId; if (id === undefined) { do { id = `__LOCALID_${generateId()}`; } while (state.threadIdMap[id]); const mappingId = createThreadMappingId(id); this._state.update({ ...state, newThreadId: id, threadIdMap: { ...state.threadIdMap, [id]: mappingId, }, threadData: { ...state.threadData, [mappingId]: { status: "new", id, remoteId: undefined, externalId: undefined, title: undefined, }, }, }); } return this.switchToThread(id); } initialize = async (threadId) => { if (this._state.value.newThreadId !== threadId) { const data = this.getItemById(threadId); if (!data) throw new Error("Thread not found"); if (data.status === "new") throw new Error("Unexpected new state"); return data.initializeTask; } return this._state.optimisticUpdate({ execute: () => { return this._options.adapter.initialize(threadId); }, optimistic: (state) => { return updateStatusReducer(state, threadId, "regular"); }, loading: (state, task) => { const mappingId = createThreadMappingId(threadId); return { ...state, threadData: { ...state.threadData, [mappingId]: { ...state.threadData[mappingId], initializeTask: task, }, }, }; }, then: (state, { remoteId, externalId }) => { const data = getThreadData(state, threadId); if (!data) return state; const mappingId = createThreadMappingId(threadId); return { ...state, threadIdMap: { ...state.threadIdMap, [remoteId]: mappingId, }, threadData: { ...state.threadData, [mappingId]: { ...data, initializeTask: Promise.resolve({ remoteId, externalId }), remoteId, externalId, }, }, }; }, }); }; generateTitle = async (threadId) => { const data = this.getItemById(threadId); if (!data) throw new Error("Thread not found"); if (data.status === "new") throw new Error("Thread is not yet initialized"); const { remoteId } = await data.initializeTask; const runtimeCore = this._hookManager.getThreadRuntimeCore(data.id); if (!runtimeCore) return; // thread is no longer running const messages = runtimeCore.messages; const stream = await this._options.adapter.generateTitle(remoteId, messages); const messageStream = AssistantMessageStream.fromAssistantStream(stream); for await (const result of messageStream) { const newTitle = result.parts.filter((c) => c.type === "text")[0]?.text; const state = this._state.baseValue; this._state.update({ ...state, threadData: { ...state.threadData, [data.id]: { ...data, title: newTitle, }, }, }); } }; rename(threadIdOrRemoteId, newTitle) { const data = this.getItemById(threadIdOrRemoteId); if (!data) throw new Error("Thread not found"); if (data.status === "new") throw new Error("Thread is not yet initialized"); return this._state.optimisticUpdate({ execute: async () => { const { remoteId } = await data.initializeTask; return this._options.adapter.rename(remoteId, newTitle); }, optimistic: (state) => { const data = getThreadData(state, threadIdOrRemoteId); if (!data) return state; return { ...state, threadData: { ...state.threadData, [data.id]: { ...data, title: newTitle, }, }, }; }, }); } async _ensureThreadIsNotMain(threadId) { if (threadId === this.newThreadId) throw new Error("Cannot ensure new thread is not main"); if (threadId === this._mainThreadId) { await this.switchToNewThread(); } } async archive(threadIdOrRemoteId) { const data = this.getItemById(threadIdOrRemoteId); if (!data) throw new Error("Thread not found"); if (data.status !== "regular") throw new Error("Thread is not yet initialized or already archived"); return this._state.optimisticUpdate({ execute: async () => { await this._ensureThreadIsNotMain(data.id); const { remoteId } = await data.initializeTask; return this._options.adapter.archive(remoteId); }, optimistic: (state) => { return updateStatusReducer(state, data.id, "archived"); }, }); } unarchive(threadIdOrRemoteId) { const data = this.getItemById(threadIdOrRemoteId); if (!data) throw new Error("Thread not found"); if (data.status !== "archived") throw new Error("Thread is not archived"); return this._state.optimisticUpdate({ execute: async () => { try { const { remoteId } = await data.initializeTask; return await this._options.adapter.unarchive(remoteId); } catch (error) { await this._ensureThreadIsNotMain(data.id); throw error; } }, optimistic: (state) => { return updateStatusReducer(state, data.id, "regular"); }, }); } async delete(threadIdOrRemoteId) { const data = this.getItemById(threadIdOrRemoteId); if (!data) throw new Error("Thread not found"); if (data.status !== "regular" && data.status !== "archived") throw new Error("Thread is not yet initialized"); return this._state.optimisticUpdate({ execute: async () => { await this._ensureThreadIsNotMain(data.id); const { remoteId } = await data.initializeTask; return await this._options.adapter.delete(remoteId); }, optimistic: (state) => { return updateStatusReducer(state, data.id, "deleted"); }, }); } async detach(threadIdOrRemoteId) { const data = this.getItemById(threadIdOrRemoteId); if (!data) throw new Error("Thread not found"); if (data.status !== "regular" && data.status !== "archived") throw new Error("Thread is not yet initialized"); await this._ensureThreadIsNotMain(data.id); this._hookManager.stopThreadRuntime(data.id); } useBoundIds = create(() => []); __internal_RenderComponent = () => { const id = useId(); useEffect(() => { this.useBoundIds.setState((s) => [...s, id], true); return () => { this.useBoundIds.setState((s) => s.filter((i) => i !== id), true); }; }, [id]); const boundIds = this.useBoundIds(); const { Provider } = this.useProvider(); const adapters = { modelContext: this.contextProvider, }; return ((boundIds.length === 0 || boundIds[0] === id) && ( // only render if the component is the first one mounted _jsx(RuntimeAdapterProvider, { adapters: adapters, children: _jsx(this._hookManager.__internal_RenderThreadRuntimes, { provider: Provider }) }))); }; } //# sourceMappingURL=RemoteThreadListThreadListRuntimeCore.js.map