UNPKG

@assistant-ui/react

Version:

TypeScript/React library for AI Chat

235 lines 9.3 kB
import { MessageRuntimeImpl, } from "./MessageRuntime.js"; import { NestedSubscriptionSubject } from "./subscribable/NestedSubscriptionSubject.js"; import { ShallowMemoizeSubject } from "./subscribable/ShallowMemoizeSubject.js"; import { ThreadComposerRuntimeImpl, } from "./ComposerRuntime.js"; import { SKIP_UPDATE } from "./subscribable/SKIP_UPDATE.js"; import { EventSubscriptionSubject } from "./subscribable/EventSubscriptionSubject.js"; import { symbolInnerMessage } from "../runtime-cores/external-store/getExternalStoreMessage.js"; const toResumeRunConfig = (message) => { return { parentId: message.parentId ?? null, sourceId: message.sourceId ?? null, runConfig: message.runConfig ?? {}, ...(message.stream ? { stream: message.stream } : {}), }; }; const toStartRunConfig = (message) => { return { parentId: message.parentId ?? null, sourceId: message.sourceId ?? null, runConfig: message.runConfig ?? {}, }; }; const toAppendMessage = (messages, message) => { if (typeof message === "string") { return { createdAt: new Date(), parentId: messages.at(-1)?.id ?? null, sourceId: null, runConfig: {}, role: "user", content: [{ type: "text", text: message }], attachments: [], metadata: { custom: {} }, }; } return { createdAt: message.createdAt ?? new Date(), parentId: message.parentId ?? messages.at(-1)?.id ?? null, sourceId: message.sourceId ?? null, role: message.role ?? "user", content: message.content, attachments: message.attachments ?? [], metadata: message.metadata ?? { custom: {} }, runConfig: message.runConfig ?? {}, startRun: message.startRun, }; }; export const getThreadState = (runtime, threadListItemState) => { const lastMessage = runtime.messages.at(-1); return Object.freeze({ threadId: threadListItemState.id, metadata: threadListItemState, capabilities: runtime.capabilities, isDisabled: runtime.isDisabled, isLoading: runtime.isLoading, isRunning: lastMessage?.role !== "assistant" ? false : lastMessage.status.type === "running", messages: runtime.messages, state: runtime.state, suggestions: runtime.suggestions, extras: runtime.extras, speech: runtime.speech, }); }; export class ThreadRuntimeImpl { get path() { return this._threadBinding.path; } get __internal_threadBinding() { return this._threadBinding; } _threadBinding; constructor(threadBinding, threadListItemBinding) { const stateBinding = new ShallowMemoizeSubject({ path: threadBinding.path, getState: () => getThreadState(threadBinding.getState(), threadListItemBinding.getState()), subscribe: (callback) => { const sub1 = threadBinding.subscribe(callback); const sub2 = threadListItemBinding.subscribe(callback); return () => { sub1(); sub2(); }; }, }); this._threadBinding = { path: threadBinding.path, getState: () => threadBinding.getState(), getStateState: () => stateBinding.getState(), outerSubscribe: (callback) => threadBinding.outerSubscribe(callback), subscribe: (callback) => threadBinding.subscribe(callback), }; this.composer = new ThreadComposerRuntimeImpl(new NestedSubscriptionSubject({ path: { ...this.path, ref: `${this.path.ref}${this.path.ref}.composer`, composerSource: "thread", }, getState: () => this._threadBinding.getState().composer, subscribe: (callback) => this._threadBinding.subscribe(callback), })); this.__internal_bindMethods(); } __internal_bindMethods() { this.append = this.append.bind(this); this.unstable_resumeRun = this.unstable_resumeRun.bind(this); this.unstable_loadExternalState = this.unstable_loadExternalState.bind(this); this.startRun = this.startRun.bind(this); this.cancelRun = this.cancelRun.bind(this); this.stopSpeaking = this.stopSpeaking.bind(this); this.export = this.export.bind(this); this.import = this.import.bind(this); this.reset = this.reset.bind(this); this.getMessageByIndex = this.getMessageByIndex.bind(this); this.getMessageById = this.getMessageById.bind(this); this.subscribe = this.subscribe.bind(this); this.unstable_on = this.unstable_on.bind(this); this.getModelContext = this.getModelContext.bind(this); this.getModelConfig = this.getModelConfig.bind(this); this.getState = this.getState.bind(this); } composer; getState() { return this._threadBinding.getStateState(); } append(message) { this._threadBinding .getState() .append(toAppendMessage(this._threadBinding.getState().messages, message)); } subscribe(callback) { return this._threadBinding.subscribe(callback); } getModelContext() { return this._threadBinding.getState().getModelContext(); } getModelConfig() { return this.getModelContext(); } startRun(configOrParentId) { const config = configOrParentId === null || typeof configOrParentId === "string" ? { parentId: configOrParentId } : configOrParentId; return this._threadBinding.getState().startRun(toStartRunConfig(config)); } unstable_resumeRun(config) { return this._threadBinding.getState().resumeRun(toResumeRunConfig(config)); } unstable_loadExternalState(state) { this._threadBinding.getState().unstable_loadExternalState(state); } cancelRun() { this._threadBinding.getState().cancelRun(); } stopSpeaking() { return this._threadBinding.getState().stopSpeaking(); } export() { return this._threadBinding.getState().export(); } import(data) { this._threadBinding.getState().import(data); } reset(initialMessages) { this._threadBinding.getState().reset(initialMessages); } getMessageByIndex(idx) { if (idx < 0) throw new Error("Message index must be >= 0"); return this._getMessageRuntime({ ...this.path, ref: `${this.path.ref}${this.path.ref}.messages[${idx}]`, messageSelector: { type: "index", index: idx }, }, () => { const messages = this._threadBinding.getState().messages; const message = messages[idx]; if (!message) return undefined; return { message, parentId: messages[idx - 1]?.id ?? null, index: idx, }; }); } getMessageById(messageId) { return this._getMessageRuntime({ ...this.path, ref: this.path.ref + `${this.path.ref}.messages[messageId=${JSON.stringify(messageId)}]`, messageSelector: { type: "messageId", messageId: messageId }, }, () => this._threadBinding.getState().getMessageById(messageId)); } _getMessageRuntime(path, callback) { return new MessageRuntimeImpl(new ShallowMemoizeSubject({ path, getState: () => { const { message, parentId, index } = callback() ?? {}; const { messages, speech: speechState } = this._threadBinding.getState(); if (!message || parentId === undefined || index === undefined) return SKIP_UPDATE; const thread = this._threadBinding.getState(); const branches = thread.getBranches(message.id); const submittedFeedback = message.metadata.submittedFeedback; return { ...message, ...{ [symbolInnerMessage]: message[symbolInnerMessage] }, index, isLast: messages.at(-1)?.id === message.id, parentId, branchNumber: branches.indexOf(message.id) + 1, branchCount: branches.length, speech: speechState?.messageId === message.id ? speechState : undefined, submittedFeedback, }; }, subscribe: (callback) => this._threadBinding.subscribe(callback), }), this._threadBinding); } _eventSubscriptionSubjects = new Map(); unstable_on(event, callback) { let subject = this._eventSubscriptionSubjects.get(event); if (!subject) { subject = new EventSubscriptionSubject({ event: event, binding: this._threadBinding, }); this._eventSubscriptionSubjects.set(event, subject); } return subject.subscribe(callback); } } //# sourceMappingURL=ThreadRuntime.js.map