UNPKG

@assistant-ui/react

Version:

TypeScript/React library for AI Chat

201 lines 8.02 kB
import { symbolInnerMessage } from "../runtime-cores/external-store/getExternalStoreMessage.js"; import { getThreadMessageText } from "../../utils/getThreadMessageText.js"; import { MessageAttachmentRuntimeImpl, } from "./AttachmentRuntime.js"; import { EditComposerRuntimeImpl, } from "./ComposerRuntime.js"; import { MessagePartRuntimeImpl, } from "./MessagePartRuntime.js"; import { NestedSubscriptionSubject } from "./subscribable/NestedSubscriptionSubject.js"; import { SKIP_UPDATE } from "./subscribable/SKIP_UPDATE.js"; import { ShallowMemoizeSubject } from "./subscribable/ShallowMemoizeSubject.js"; const COMPLETE_STATUS = Object.freeze({ type: "complete", }); export const toMessagePartStatus = (message, partIndex, part) => { if (message.role !== "assistant") return COMPLETE_STATUS; if (part.type === "tool-call") { if (!part.result) { return message.status; } else { return COMPLETE_STATUS; } } const isLastPart = partIndex === Math.max(0, message.content.length - 1); if (message.status.type === "requires-action") return COMPLETE_STATUS; return isLastPart ? message.status : COMPLETE_STATUS; }; const getMessagePartState = (message, partIndex) => { const part = message.content[partIndex]; if (!part) { return SKIP_UPDATE; } // if the message part is the same, don't update const status = toMessagePartStatus(message, partIndex, part); return Object.freeze({ ...part, ...{ [symbolInnerMessage]: part[symbolInnerMessage] }, status, }); }; export class MessageRuntimeImpl { _core; _threadBinding; get path() { return this._core.path; } constructor(_core, _threadBinding) { this._core = _core; this._threadBinding = _threadBinding; this.composer = new EditComposerRuntimeImpl(new NestedSubscriptionSubject({ path: { ...this.path, ref: `${this.path.ref}${this.path.ref}.composer`, composerSource: "edit", }, getState: this._getEditComposerRuntimeCore, subscribe: (callback) => this._threadBinding.subscribe(callback), }), () => this._threadBinding.getState().beginEdit(this._core.getState().id)); this.__internal_bindMethods(); } __internal_bindMethods() { this.reload = this.reload.bind(this); this.getState = this.getState.bind(this); this.subscribe = this.subscribe.bind(this); this.getMessagePartByIndex = this.getMessagePartByIndex.bind(this); this.getMessagePartByToolCallId = this.getMessagePartByToolCallId.bind(this); this.getAttachmentByIndex = this.getAttachmentByIndex.bind(this); this.unstable_getCopyText = this.unstable_getCopyText.bind(this); this.speak = this.speak.bind(this); this.stopSpeaking = this.stopSpeaking.bind(this); this.submitFeedback = this.submitFeedback.bind(this); this.switchToBranch = this.switchToBranch.bind(this); } composer; _getEditComposerRuntimeCore = () => { return this._threadBinding .getState() .getEditComposer(this._core.getState().id); }; getState() { return this._core.getState(); } reload(reloadConfig = {}) { const editComposerRuntimeCore = this._getEditComposerRuntimeCore(); const composerRuntimeCore = editComposerRuntimeCore ?? this._threadBinding.getState().composer; const composer = editComposerRuntimeCore ?? composerRuntimeCore; const { runConfig = composer.runConfig } = reloadConfig; const state = this._core.getState(); if (state.role !== "assistant") throw new Error("Can only reload assistant messages"); this._threadBinding.getState().startRun({ parentId: state.parentId, sourceId: state.id, runConfig, }); } speak() { const state = this._core.getState(); return this._threadBinding.getState().speak(state.id); } stopSpeaking() { const state = this._core.getState(); const thread = this._threadBinding.getState(); if (thread.speech?.messageId === state.id) { this._threadBinding.getState().stopSpeaking(); } else { throw new Error("Message is not being spoken"); } } submitFeedback({ type }) { const state = this._core.getState(); this._threadBinding.getState().submitFeedback({ messageId: state.id, type, }); } switchToBranch({ position, branchId, }) { const state = this._core.getState(); if (branchId && position) { throw new Error("May not specify both branchId and position"); } else if (!branchId && !position) { throw new Error("Must specify either branchId or position"); } const thread = this._threadBinding.getState(); const branches = thread.getBranches(state.id); let targetBranch = branchId; if (position === "previous") { targetBranch = branches[state.branchNumber - 2]; } else if (position === "next") { targetBranch = branches[state.branchNumber]; } if (!targetBranch) throw new Error("Branch not found"); this._threadBinding.getState().switchToBranch(targetBranch); } unstable_getCopyText() { return getThreadMessageText(this.getState()); } subscribe(callback) { return this._core.subscribe(callback); } getMessagePartByIndex(idx) { if (idx < 0) throw new Error("Message part index must be >= 0"); return new MessagePartRuntimeImpl(new ShallowMemoizeSubject({ path: { ...this.path, ref: `${this.path.ref}${this.path.ref}.content[${idx}]`, messagePartSelector: { type: "index", index: idx }, }, getState: () => { return getMessagePartState(this.getState(), idx); }, subscribe: (callback) => this._core.subscribe(callback), }), this._core, this._threadBinding); } getMessagePartByToolCallId(toolCallId) { return new MessagePartRuntimeImpl(new ShallowMemoizeSubject({ path: { ...this.path, ref: this.path.ref + `${this.path.ref}.content[toolCallId=${JSON.stringify(toolCallId)}]`, messagePartSelector: { type: "toolCallId", toolCallId }, }, getState: () => { const state = this._core.getState(); const idx = state.content.findIndex((part) => part.type === "tool-call" && part.toolCallId === toolCallId); if (idx === -1) return SKIP_UPDATE; return getMessagePartState(state, idx); }, subscribe: (callback) => this._core.subscribe(callback), }), this._core, this._threadBinding); } getAttachmentByIndex(idx) { return new MessageAttachmentRuntimeImpl(new ShallowMemoizeSubject({ path: { ...this.path, ref: `${this.path.ref}${this.path.ref}.attachments[${idx}]`, attachmentSource: "message", attachmentSelector: { type: "index", index: idx }, }, getState: () => { const attachments = this.getState().attachments; const attachment = attachments?.[idx]; if (!attachment) return SKIP_UPDATE; return { ...attachment, source: "message", }; }, subscribe: (callback) => this._core.subscribe(callback), })); } } //# sourceMappingURL=MessageRuntime.js.map