UNPKG

@assistant-ui/react

Version:

Typescript/React library for AI Chat

254 lines (253 loc) 9.08 kB
"use strict";
var __defProp = Object.defineProperty;
var __getOwnPropDesc = Object.getOwnPropertyDescriptor;
var __getOwnPropNames = Object.getOwnPropertyNames;
var __hasOwnProp = Object.prototype.hasOwnProperty;
// Bundler helper: defines each entry of `all` on `target` as an enumerable getter.
var __export = (target, all) => {
  for (var name in all)
    __defProp(target, name, { get: all[name], enumerable: true });
};
// Bundler helper: mirrors the own properties of `from` onto `to` as getters,
// skipping keys `to` already has and the `except` key.
var __copyProps = (to, from, except, desc) => {
  if (from && typeof from === "object" || typeof from === "function") {
    for (let key of __getOwnPropNames(from))
      if (!__hasOwnProp.call(to, key) && key !== except)
        __defProp(to, key, { get: () => from[key], enumerable: !(desc = __getOwnPropDesc(from, key)) || desc.enumerable });
  }
  return to;
};
// Bundler helper: wraps an export map in an object flagged with `__esModule`.
var __toCommonJS = (mod) => __copyProps(__defProp({}, "__esModule", { value: true }), mod);

// src/api/ThreadRuntime.ts
var ThreadRuntime_exports = {};
__export(ThreadRuntime_exports, {
  ThreadRuntimeImpl: () => ThreadRuntimeImpl,
  getThreadState: () => getThreadState
});
module.exports = __toCommonJS(ThreadRuntime_exports);
var import_MessageRuntime = require("./MessageRuntime.js");
var import_NestedSubscriptionSubject = require("./subscribable/NestedSubscriptionSubject.js");
var import_ShallowMemoizeSubject = require("./subscribable/ShallowMemoizeSubject.js");
var import_ComposerRuntime = require("./ComposerRuntime.js");
var import_SKIP_UPDATE = require("./subscribable/SKIP_UPDATE.js");
var import_EventSubscriptionSubject = require("./subscribable/EventSubscriptionSubject.js");
var import_getExternalStoreMessage = require("../runtimes/external-store/getExternalStoreMessage.js");

/**
 * Normalizes a resume-run request: `parentId` and `sourceId` default to
 * `null`, `runConfig` defaults to `{}`, and `stream` is passed through as-is.
 *
 * @param {object} message - Partial resume-run configuration.
 * @returns {object} Configuration with `parentId`, `sourceId`, `runConfig`, `stream`.
 */
var toResumeRunConfig = (message) => {
  return {
    parentId: message.parentId ?? null,
    sourceId: message.sourceId ?? null,
    runConfig: message.runConfig ?? {},
    stream: message.stream
  };
};

/**
 * Normalizes a start-run request: `parentId` and `sourceId` default to
 * `null`, `runConfig` defaults to `{}`. Any other keys are dropped.
 *
 * @param {object} message - Partial start-run configuration.
 * @returns {object} Configuration with `parentId`, `sourceId`, `runConfig`.
 */
var toStartRunConfig = (message) => {
  return {
    parentId: message.parentId ?? null,
    sourceId: message.sourceId ?? null,
    runConfig: message.runConfig ?? {}
  };
};

/**
 * Converts the argument of `append` into a complete append-message object.
 *
 * - A string becomes a user text message parented to the last message in
 *   `messages` (or `null` when there is none).
 * - An object that already has truthy `role`, `parentId` and `attachments`
 *   is returned unchanged (same reference).
 * - Any other object is copied with defaults filled in for `parentId`,
 *   `sourceId`, `role` and `attachments`.
 *
 * @param {Array<{id: string}>} messages - Current thread messages.
 * @param {string|object} message - Text or partial message to append.
 * @returns {object} The message to hand to the thread's `append`.
 */
var toAppendMessage = (messages, message) => {
  if (typeof message === "string") {
    return {
      parentId: messages.at(-1)?.id ?? null,
      sourceId: null,
      runConfig: {},
      role: "user",
      content: [{ type: "text", text: message }],
      attachments: []
    };
  }
  if (message.role && message.parentId && message.attachments) {
    return message;
  }
  return {
    ...message,
    parentId: message.parentId ?? messages.at(-1)?.id ?? null,
    sourceId: message.sourceId ?? null,
    role: message.role ?? "user",
    attachments: message.attachments ?? []
  };
};

/**
 * Builds the frozen public thread state from the thread runtime and the
 * state of its thread-list item.
 *
 * `isRunning` is true only when the last message has role "assistant" and
 * its `status.type` is "running".
 *
 * @param {object} runtime - Thread runtime exposing `messages`, `capabilities`,
 *   `isDisabled`, `suggestions`, `extras` and `speech`.
 * @param {object} threadListItemState - Thread-list item state; its `id`
 *   becomes `threadId` and the whole object becomes `metadata`.
 * @returns {Readonly<object>} Frozen (shallow) thread state.
 */
var getThreadState = (runtime, threadListItemState) => {
  const lastMessage = runtime.messages.at(-1);
  return Object.freeze({
    threadId: threadListItemState.id,
    metadata: threadListItemState,
    capabilities: runtime.capabilities,
    isDisabled: runtime.isDisabled,
    isRunning: lastMessage?.role !== "assistant" ? false : lastMessage.status.type === "running",
    messages: runtime.messages,
    suggestions: runtime.suggestions,
    extras: runtime.extras,
    speech: runtime.speech
  });
};

/**
 * Public runtime facade for a single thread. It wraps a thread binding,
 * exposes the derived thread state, forwards commands to the underlying
 * thread, and creates composer and message runtimes scoped to this thread.
 */
var ThreadRuntimeImpl = class {
  /** Path descriptor of the wrapped thread binding. */
  get path() {
    return this._threadBinding.path;
  }
  /** Internal accessor for the wrapped binding. */
  get __internal_threadBinding() {
    return this._threadBinding;
  }
  _threadBinding;
  /**
   * @param {object} threadBinding - Binding providing `path`, `getState`,
   *   `subscribe` and `outerSubscribe` for the underlying thread.
   * @param {object} threadListItemBinding - Binding providing `getState`
   *   and `subscribe` for the thread's list item.
   */
  constructor(threadBinding, threadListItemBinding) {
    // Derived public state; notified when either source binding notifies.
    const stateBinding = new import_ShallowMemoizeSubject.ShallowMemoizeSubject({
      path: threadBinding.path,
      getState: () => getThreadState(
        threadBinding.getState(),
        threadListItemBinding.getState()
      ),
      subscribe: (callback) => {
        const sub1 = threadBinding.subscribe(callback);
        const sub2 = threadListItemBinding.subscribe(callback);
        return () => {
          sub1();
          sub2();
        };
      }
    });
    // `getState` returns the underlying thread object (used to issue
    // commands); `getStateState` returns the derived public state.
    this._threadBinding = {
      path: threadBinding.path,
      getState: () => threadBinding.getState(),
      getStateState: () => stateBinding.getState(),
      outerSubscribe: (callback) => threadBinding.outerSubscribe(callback),
      subscribe: (callback) => threadBinding.subscribe(callback)
    };
    this.composer = new import_ComposerRuntime.ThreadComposerRuntimeImpl(
      new import_NestedSubscriptionSubject.NestedSubscriptionSubject({
        path: {
          ...this.path,
          // NOTE(review): this produces "<ref><ref>.composer" (the parent ref
          // appears twice). Left unchanged because the ref string is
          // observable - confirm whether the duplication is intended.
          ref: this.path.ref + `${this.path.ref}.composer`,
          composerSource: "thread"
        },
        getState: () => this._threadBinding.getState().composer,
        subscribe: (callback) => this._threadBinding.subscribe(callback)
      })
    );
  }
  /**
   * Replaces the public methods with instance-bound copies so they can be
   * passed around detached from the runtime object.
   */
  __internal_bindMethods() {
    this.append = this.append.bind(this);
    this.startRun = this.startRun.bind(this);
    // Previously missing: without this, a detached `unstable_resumeRun`
    // loses `this`, unlike every other public method bound here.
    this.unstable_resumeRun = this.unstable_resumeRun.bind(this);
    this.cancelRun = this.cancelRun.bind(this);
    this.stopSpeaking = this.stopSpeaking.bind(this);
    this.export = this.export.bind(this);
    this.import = this.import.bind(this);
    this.getMesssageByIndex = this.getMesssageByIndex.bind(this);
    this.getMesssageById = this.getMesssageById.bind(this);
    this.subscribe = this.subscribe.bind(this);
    this.unstable_on = this.unstable_on.bind(this);
    this.getModelContext = this.getModelContext.bind(this);
    this.getModelConfig = this.getModelConfig.bind(this);
    this.getState = this.getState.bind(this);
  }
  composer;
  /** @returns {Readonly<object>} The derived public thread state. */
  getState() {
    return this._threadBinding.getStateState();
  }
  /**
   * Appends a message to the thread after normalizing it with
   * `toAppendMessage`.
   *
   * @param {string|object} message - Text or partial message to append.
   */
  append(message) {
    this._threadBinding.getState().append(
      toAppendMessage(this._threadBinding.getState().messages, message)
    );
  }
  /**
   * Subscribes to notifications of the underlying thread binding.
   *
   * @param {Function} callback - Invoked on notification.
   * @returns {Function} Whatever the binding's `subscribe` returns
   *   (used as an unsubscribe function elsewhere in this file).
   */
  subscribe(callback) {
    return this._threadBinding.subscribe(callback);
  }
  /** @returns {*} The result of the underlying thread's `getModelContext()`. */
  getModelContext() {
    return this._threadBinding.getState().getModelContext();
  }
  /** Alias of `getModelContext`. */
  getModelConfig() {
    return this.getModelContext();
  }
  /**
   * Starts a run on the thread.
   *
   * @param {string|null|object} configOrParentId - Either a parent message
   *   id (string or `null`) or a start-run configuration object.
   * @returns {*} The result of the underlying thread's `startRun`.
   */
  startRun(configOrParentId) {
    const config = configOrParentId === null || typeof configOrParentId === "string" ? { parentId: configOrParentId } : configOrParentId;
    return this._threadBinding.getState().startRun(toStartRunConfig(config));
  }
  /**
   * Resumes a run on the thread.
   *
   * @param {object} config - Resume-run configuration (see `toResumeRunConfig`).
   * @returns {*} The result of the underlying thread's `resumeRun`.
   */
  unstable_resumeRun(config) {
    return this._threadBinding.getState().resumeRun(toResumeRunConfig(config));
  }
  /** Forwards to the underlying thread's `cancelRun`. */
  cancelRun() {
    this._threadBinding.getState().cancelRun();
  }
  /** @returns {*} The result of the underlying thread's `stopSpeaking`. */
  stopSpeaking() {
    return this._threadBinding.getState().stopSpeaking();
  }
  /** @returns {*} The result of the underlying thread's `export`. */
  export() {
    return this._threadBinding.getState().export();
  }
  /**
   * Forwards `data` to the underlying thread's `import`.
   *
   * @param {*} data - Data to import.
   */
  import(data) {
    this._threadBinding.getState().import(data);
  }
  /**
   * Creates a message runtime tracking the message at position `idx`.
   * The spelling of the method name is part of the public interface.
   *
   * @param {number} idx - Zero-based message index.
   * @returns {object} A `MessageRuntimeImpl` for that position.
   * @throws {Error} When `idx` is negative.
   */
  getMesssageByIndex(idx) {
    if (idx < 0) throw new Error("Message index must be >= 0");
    return this._getMessageRuntime(
      {
        ...this.path,
        // NOTE(review): parent ref is prepended twice here as well; confirm.
        ref: this.path.ref + `${this.path.ref}.messages[${idx}]`,
        messageSelector: { type: "index", index: idx }
      },
      () => {
        const messages = this._threadBinding.getState().messages;
        const message = messages[idx];
        if (!message) return void 0;
        // The parent is the preceding message in the list, or null for the first.
        return { message, parentId: messages[idx - 1]?.id ?? null };
      }
    );
  }
  /**
   * Creates a message runtime tracking the message with the given id.
   * The spelling of the method name is part of the public interface.
   *
   * @param {string} messageId - Id of the message to track.
   * @returns {object} A `MessageRuntimeImpl` for that message.
   */
  getMesssageById(messageId) {
    return this._getMessageRuntime(
      {
        ...this.path,
        // NOTE(review): parent ref is prepended twice here as well; confirm.
        ref: this.path.ref + `${this.path.ref}.messages[messageId=${JSON.stringify(messageId)}]`,
        messageSelector: { type: "messageId", messageId }
      },
      () => this._threadBinding.getState().getMessageById(messageId)
    );
  }
  /**
   * Shared factory behind `getMesssageByIndex` / `getMesssageById`.
   *
   * @param {object} path - Path descriptor for the message runtime.
   * @param {Function} callback - Returns `{ message, parentId }` for the
   *   tracked message, or a nullish value when it does not exist.
   * @returns {object} A `MessageRuntimeImpl` whose state is the message
   *   plus `isLast`, `parentId`, branch, speech and feedback information.
   */
  _getMessageRuntime(path, callback) {
    return new import_MessageRuntime.MessageRuntimeImpl(
      new import_ShallowMemoizeSubject.ShallowMemoizeSubject({
        path,
        getState: () => {
          const { message, parentId } = callback() ?? {};
          // Read the thread once and reuse it for everything below.
          const thread = this._threadBinding.getState();
          const { messages, speech: speechState } = thread;
          // Message not available: yield the SKIP_UPDATE sentinel instead of a state.
          if (!message || parentId === void 0) return import_SKIP_UPDATE.SKIP_UPDATE;
          const branches = thread.getBranches(message.id);
          const submittedFeedback = thread.getSubmittedFeedback(message.id);
          return {
            ...message,
            // Explicitly carry over the symbol-keyed inner message.
            ...{ [import_getExternalStoreMessage.symbolInnerMessage]: message[import_getExternalStoreMessage.symbolInnerMessage] },
            isLast: messages.at(-1)?.id === message.id,
            parentId,
            // 1-based position among the branches (0 if the id is not listed).
            branchNumber: branches.indexOf(message.id) + 1,
            branchCount: branches.length,
            speech: speechState?.messageId === message.id ? speechState : void 0,
            submittedFeedback
          };
        },
        subscribe: (callback2) => this._threadBinding.subscribe(callback2)
      }),
      this._threadBinding
    );
  }
  // One lazily created EventSubscriptionSubject per event name.
  _eventSubscriptionSubjects = /* @__PURE__ */ new Map();
  /**
   * Subscribes to a thread event.
   *
   * @param {string} event - Event name.
   * @param {Function} callback - Invoked when the event fires.
   * @returns {*} Whatever the event subject's `subscribe` returns.
   */
  unstable_on(event, callback) {
    let subject = this._eventSubscriptionSubjects.get(event);
    if (!subject) {
      subject = new import_EventSubscriptionSubject.EventSubscriptionSubject({
        event,
        binding: this._threadBinding
      });
      this._eventSubscriptionSubjects.set(event, subject);
    }
    return subject.subscribe(callback);
  }
};
// Annotate the CommonJS export names for ESM import in node:
0 && (module.exports = {
  ThreadRuntimeImpl,
  getThreadState
});
//# sourceMappingURL=ThreadRuntime.js.map