UNPKG

@assistant-ui/react

Version:

React components for AI chat.

547 lines (463 loc) 15.8 kB
import {
  AddToolResultOptions,
  ThreadSuggestion,
  RuntimeCapabilities,
  SubmitFeedbackOptions,
  ThreadRuntimeCore,
  SpeechState,
  SubmittedFeedback,
  ThreadRuntimeEventType,
} from "../runtimes/core/ThreadRuntimeCore";
import { ExportedMessageRepository } from "../runtimes/utils/MessageRepository";
import {
  AppendMessage,
  ModelConfig,
  ThreadMessage,
  Unsubscribe,
} from "../types";
import {
  MessageRuntime,
  MessageRuntimeImpl,
  MessageState,
} from "./MessageRuntime";
import { NestedSubscriptionSubject } from "./subscribable/NestedSubscriptionSubject";
import { ShallowMemoizeSubject } from "./subscribable/ShallowMemoizeSubject";
import {
  Subscribable,
  SubscribableWithState,
} from "./subscribable/Subscribable";
import {
  ThreadComposerRuntime,
  ThreadComposerRuntimeImpl,
} from "./ComposerRuntime";
import { LazyMemoizeSubject } from "./subscribable/LazyMemoizeSubject";
import { SKIP_UPDATE } from "./subscribable/SKIP_UPDATE";
import { ComposerRuntimeCore } from "../runtimes/core/ComposerRuntimeCore";
import { MessageRuntimePath, ThreadRuntimePath } from "./RuntimePathTypes";

/**
 * Input accepted by `ThreadRuntime.append`.
 *
 * - A plain string is sent as a single user text message.
 * - An object may omit `parentId`, `role` and `attachments`; the missing
 *   values are filled in by `toAppendMessage`.
 */
export type CreateAppendMessage =
  | string
  | {
      parentId?: string | null | undefined;
      role?: AppendMessage["role"] | undefined;
      content: AppendMessage["content"];
      attachments?: AppendMessage["attachments"] | undefined;
      startRun?: boolean | undefined;
    };

/**
 * Normalizes a `CreateAppendMessage` into a complete `AppendMessage`.
 *
 * Defaults applied to missing fields:
 * - `parentId`: id of the last message in `messages`, or `null` when empty.
 *   Because `??` is used, an explicit `parentId: null` is treated the same as
 *   an omitted one and also falls back to the last message.
 * - `role`: `"user"`.
 * - `attachments`: `[]`.
 *
 * @param messages - Messages of the currently selected branch.
 * @param message - Caller-supplied message (string or partial object).
 * @returns The message to hand to `ThreadRuntimeCore.append`.
 */
const toAppendMessage = (
  messages: readonly ThreadMessage[],
  message: CreateAppendMessage,
): AppendMessage => {
  if (typeof message === "string") {
    return {
      parentId: messages.at(-1)?.id ?? null,
      role: "user",
      content: [{ type: "text", text: message }],
      attachments: [],
    };
  }

  // Fast path: every optional field is already populated, pass it through.
  if (message.role && message.parentId && message.attachments) {
    return message as AppendMessage;
  }

  return {
    ...message,
    parentId: message.parentId ?? messages.at(-1)?.id ?? null,
    role: message.role ?? "user",
    attachments: message.attachments ?? [],
  } as AppendMessage;
};

/**
 * Subscribable handle to the current `ThreadRuntimeCore`, addressed by a
 * `ThreadRuntimePath`.
 *
 * `outerSubscribe` is used by `ThreadRuntimeImpl.unstable_on` to re-attach
 * event listeners; presumably it fires when the bound core instance itself
 * is replaced (NOTE(review): confirm against the binding's implementation).
 */
export type ThreadRuntimeCoreBinding = SubscribableWithState<
  ThreadRuntimeCore,
  ThreadRuntimePath
> & {
  outerSubscribe(callback: () => void): Unsubscribe;
};

export type ThreadState = Readonly<{
  /**
   * The thread ID.
   */
  threadId: string;

  /**
   * Whether the thread is disabled. Disabled threads cannot receive new messages.
   */
  isDisabled: boolean;

  /**
   * Whether the thread is running. A thread is considered running when there is an active stream connection to the backend.
   */
  isRunning: boolean;

  /**
   * The capabilities of the thread, such as whether the thread supports editing, branch switching, etc.
   */
  capabilities: RuntimeCapabilities;

  /**
   * The messages in the currently selected branch of the thread.
   */
  messages: readonly ThreadMessage[];

  /**
   * Follow up message suggestions to show the user.
   */
  suggestions: readonly ThreadSuggestion[];

  /**
   * Custom extra information provided by the runtime.
   */
  extras: unknown;

  /**
   * @deprecated This API is still under active development and might change without notice.
   */
  speech: SpeechState | undefined;
}>;

/**
 * Builds a frozen `ThreadState` snapshot from a thread core.
 *
 * `isRunning` is derived from the last message: it is `true` only when that
 * message has role `"assistant"` and its status type is `"running"`.
 *
 * @param runtime - The core to read from.
 * @returns A frozen snapshot of the thread's public state.
 */
export const getThreadState = (runtime: ThreadRuntimeCore): ThreadState => {
  const lastMessage = runtime.messages.at(-1);
  return Object.freeze({
    threadId: runtime.threadId,
    capabilities: runtime.capabilities,
    isDisabled: runtime.isDisabled,
    isRunning:
      lastMessage?.role !== "assistant"
        ? false
        : lastMessage.status.type === "running",
    messages: runtime.messages,
    suggestions: runtime.suggestions,
    extras: runtime.extras,
    speech: runtime.speech,
  });
};

export type ThreadRuntime = Readonly<{
  readonly path: ThreadRuntimePath;
  readonly composer: ThreadComposerRuntime;
  getState(): ThreadState;

  /**
   * @deprecated This method will be removed in 0.6.0. Submit feedback if you need this functionality.
   */
  unstable_getCore(): ThreadRuntimeCore;

  append(message: CreateAppendMessage): void;
  startRun(parentId: string | null): void;
  subscribe(callback: () => void): Unsubscribe;
  cancelRun(): void;
  getModelConfig(): ModelConfig;
  export(): ExportedMessageRepository;
  import(repository: ExportedMessageRepository): void;

  getMesssageByIndex(idx: number): MessageRuntime;
  getMesssageById(messageId: string): MessageRuntime;

  /**
   * @deprecated This API is still under active development and might change without notice.
   */
  stopSpeaking: () => void;

  unstable_on(event: ThreadRuntimeEventType, callback: () => void): Unsubscribe;

  // Legacy methods with deprecations

  /**
   * @deprecated Use `getState().capabilities` instead. This will be removed in 0.6.0.
   */
  capabilities: Readonly<RuntimeCapabilities>;

  /**
   * @deprecated Use `getState().threadId` instead. This will be removed in 0.6.0.
   */
  threadId: string;

  /**
   * @deprecated Use `getState().isDisabled` instead. This will be removed in 0.6.0.
   */
  isDisabled: boolean;

  /**
   * @deprecated Use `getState().isRunning` instead. This will be removed in 0.6.0.
   */
  isRunning: boolean;

  /**
   * @deprecated Use `getState().messages` instead. This will be removed in 0.6.0.
   */
  messages: readonly ThreadMessage[];

  /**
   * @deprecated Use `getState().suggestions` instead. This will be removed in 0.6.0.
   */
  suggestions: readonly ThreadSuggestion[];

  /**
   * @deprecated Use `getState().speech` instead. This will be removed in 0.6.0.
   */
  speech: SpeechState | undefined;

  /**
   * @deprecated Use `getState().extras` instead. This will be removed in 0.6.0.
   */
  extras: unknown;

  /**
   * @deprecated Use `getMesssageById(id).getState().branchNumber` / `getMesssageById(id).getState().branchCount` instead. This will be removed in 0.6.0.
   */
  getBranches: (messageId: string) => readonly string[];

  /**
   * @deprecated Use `getMesssageById(id).switchToBranch({ options })` instead. This will be removed in 0.6.0.
   */
  switchToBranch: (branchId: string) => void;

  /**
   * @deprecated Use `getMesssageById(id).getContentPartByToolCallId(toolCallId).addToolResult({ result })` instead. This will be removed in 0.6.0.
   */
  addToolResult: (options: AddToolResultOptions) => void;

  /**
   * @deprecated Use `getMesssageById(id).speak()` instead. This will be removed in 0.6.0.
   */
  speak: (messageId: string) => void;

  /**
   * @deprecated Use `getMesssageById(id).getState().submittedFeedback` instead. This will be removed in 0.6.0.
   */
  getSubmittedFeedback: (messageId: string) => SubmittedFeedback | undefined;

  /**
   * @deprecated Use `getMesssageById(id).submitFeedback({ type })` instead. This will be removed in 0.6.0.
   */
  submitFeedback: (feedback: SubmitFeedbackOptions) => void;

  /**
   * @deprecated Use `getMesssageById(id).composer` instead. This will be removed in 0.6.0.
   */
  getEditComposer: (messageId: string) => ComposerRuntimeCore | undefined;

  /**
   * @deprecated Use `getMesssageById(id).composer.beginEdit()` instead. This will be removed in 0.6.0.
   */
  beginEdit: (messageId: string) => void;
}>;

/**
 * Public thread runtime built on top of a `ThreadRuntimeCoreBinding`.
 *
 * Read-only state (`getState()`) is served through a `LazyMemoizeSubject`
 * over `getThreadState`; actions are forwarded to whichever core the binding
 * currently returns. Message and composer runtimes are created with paths
 * derived from this runtime's `path`.
 */
export class ThreadRuntimeImpl
  implements Omit<ThreadRuntimeCore, "getMessageById">, ThreadRuntime
{
  public get path() {
    return this._threadBinding.path;
  }

  /**
   * @deprecated Use `getState().threadId` instead. This will be removed in 0.6.0.
   */
  public get threadId() {
    return this.getState().threadId;
  }

  /**
   * @deprecated Use `getState().isDisabled` instead. This will be removed in 0.6.0.
   */
  public get isDisabled() {
    return this.getState().isDisabled;
  }

  /**
   * @deprecated Use `getState().isRunning` instead. This will be removed in 0.6.0.
   */
  public get isRunning() {
    return this.getState().isRunning;
  }

  /**
   * @deprecated Use `getState().capabilities` instead. This will be removed in 0.6.0.
   */
  public get capabilities() {
    return this.getState().capabilities;
  }

  /**
   * @deprecated Use `getState().extras` instead. This will be removed in 0.6.0.
   */
  public get extras() {
    return this._threadBinding.getState().extras;
  }

  /**
   * @deprecated Use `getState().suggestions` instead. This will be removed in 0.6.0.
   */
  public get suggestions() {
    return this._threadBinding.getState().suggestions;
  }

  /**
   * @deprecated Use `getState().messages` instead. This will be removed in 0.6.0.
   */
  public get messages() {
    return this._threadBinding.getState().messages;
  }

  /**
   * @deprecated Use `getState().speech` instead. This will be removed in 0.6.0.
   */
  public get speech() {
    return this._threadBinding.getState().speech;
  }

  /**
   * @deprecated This method will be removed in 0.6.0. Submit feedback if you need this functionality.
   */
  public unstable_getCore() {
    return this._threadBinding.getState();
  }

  /**
   * The wrapped binding. `getState()` returns the raw core; `getStateState()`
   * returns the memoized public `ThreadState` snapshot.
   */
  private readonly _threadBinding: ThreadRuntimeCoreBinding & {
    getStateState(): ThreadState;
  };

  constructor(threadBinding: ThreadRuntimeCoreBinding) {
    const stateBinding = new LazyMemoizeSubject({
      path: threadBinding.path,
      getState: () => getThreadState(threadBinding.getState()),
      subscribe: (callback) => threadBinding.subscribe(callback),
    });

    this._threadBinding = {
      path: threadBinding.path,
      getState: () => threadBinding.getState(),
      getStateState: () => stateBinding.getState(),
      outerSubscribe: (callback) => threadBinding.outerSubscribe(callback),
      subscribe: (callback) => threadBinding.subscribe(callback),
    };

    this.composer = new ThreadComposerRuntimeImpl(
      new NestedSubscriptionSubject({
        path: {
          ...this.path,
          // Append a single segment to the parent ref (previously the parent
          // ref was concatenated twice, e.g. "threadthread.composer").
          ref: `${this.path.ref}.composer`,
          composerSource: "thread",
        },
        getState: () => this._threadBinding.getState().composer,
        subscribe: (callback) => this._threadBinding.subscribe(callback),
      }),
    );
  }

  public readonly composer;

  /** Returns the memoized, frozen `ThreadState` snapshot. */
  public getState() {
    return this._threadBinding.getStateState();
  }

  /**
   * Appends a message to the thread. Missing `parentId`/`role`/`attachments`
   * are filled in from the current branch (see `CreateAppendMessage`).
   */
  public append(message: CreateAppendMessage) {
    const core = this._threadBinding.getState();
    core.append(toAppendMessage(core.messages, message));
  }

  public subscribe(callback: () => void) {
    return this._threadBinding.subscribe(callback);
  }

  /**
   * @deprecated Use `getMesssageById(id).getState().branchNumber` / `getMesssageById(id).getState().branchCount` instead. This will be removed in 0.6.0.
   */
  public getBranches(messageId: string) {
    return this._threadBinding.getState().getBranches(messageId);
  }

  public getModelConfig() {
    return this._threadBinding.getState().getModelConfig();
  }

  // TODO sometimes you want to continue when there is no child message
  public startRun(parentId: string | null) {
    return this._threadBinding.getState().startRun(parentId);
  }

  public cancelRun() {
    this._threadBinding.getState().cancelRun();
  }

  /**
   * @deprecated Use `getMesssageById(id).getContentPartByToolCallId(toolCallId).addToolResult({ result })` instead. This will be removed in 0.6.0.
   */
  public addToolResult(options: AddToolResultOptions) {
    this._threadBinding.getState().addToolResult(options);
  }

  /**
   * @deprecated Use `getMesssageById(id).switchToBranch({ options })` instead. This will be removed in 0.6.0.
   */
  public switchToBranch(branchId: string) {
    return this._threadBinding.getState().switchToBranch(branchId);
  }

  /**
   * @deprecated Use `getMesssageById(id).speak()` instead. This will be removed in 0.6.0.
   */
  public speak(messageId: string) {
    return this._threadBinding.getState().speak(messageId);
  }

  /**
   * @deprecated This API is still under active development and might change without notice.
   */
  public stopSpeaking() {
    return this._threadBinding.getState().stopSpeaking();
  }

  /**
   * @deprecated Use `getMesssageById(id).getState().submittedFeedback` instead. This will be removed in 0.6.0.
   */
  public getSubmittedFeedback(messageId: string) {
    return this._threadBinding.getState().getSubmittedFeedback(messageId);
  }

  /**
   * @deprecated Use `getMesssageById(id).submitFeedback({ type })` instead. This will be removed in 0.6.0.
   */
  public submitFeedback(options: SubmitFeedbackOptions) {
    return this._threadBinding.getState().submitFeedback(options);
  }

  /**
   * @deprecated Use `getMesssageById(id).composer` instead. This will be removed in 0.6.0.
   */
  public getEditComposer(messageId: string) {
    return this._threadBinding.getState().getEditComposer(messageId);
  }

  /**
   * @deprecated Use `getMesssageById(id).composer.beginEdit()` instead. This will be removed in 0.6.0.
   */
  public beginEdit(messageId: string) {
    return this._threadBinding.getState().beginEdit(messageId);
  }

  public export() {
    return this._threadBinding.getState().export();
  }

  public import(data: ExportedMessageRepository) {
    this._threadBinding.getState().import(data);
  }

  /**
   * Returns a runtime for the message at position `idx` of the currently
   * selected branch. The runtime re-resolves the index on every state read.
   *
   * @throws Error when `idx` is negative.
   */
  public getMesssageByIndex(idx: number) {
    if (idx < 0) throw new Error("Message index must be >= 0");

    return this._getMessageRuntime(
      {
        ...this.path,
        ref: `${this.path.ref}.messages[${idx}]`,
        messageSelector: { type: "index", index: idx },
      },
      () => {
        const messages = this._threadBinding.getState().messages;
        const message = messages[idx];
        if (!message) return undefined;
        return {
          message,
          // For idx === 0, messages[-1] is undefined, so the parent is null.
          parentId: messages[idx - 1]?.id ?? null,
        };
      },
    );
  }

  /**
   * Returns a runtime for the message with the given id, resolved through
   * the core's `getMessageById` on every state read.
   */
  public getMesssageById(messageId: string) {
    return this._getMessageRuntime(
      {
        ...this.path,
        ref: `${this.path.ref}.messages[messageId=${JSON.stringify(messageId)}]`,
        messageSelector: { type: "messageId", messageId },
      },
      () => this._threadBinding.getState().getMessageById(messageId),
    );
  }

  /**
   * Wraps a message lookup in a `MessageRuntimeImpl` whose state combines the
   * message with thread-derived fields (branch info, speech, feedback).
   *
   * @param path - Path for the new message runtime.
   * @param resolve - Looks up the message and its parent id, or returns
   *   `undefined` when the message cannot be found.
   */
  private _getMessageRuntime(
    path: MessageRuntimePath,
    resolve: () =>
      | { parentId: string | null; message: ThreadMessage }
      | undefined,
  ) {
    return new MessageRuntimeImpl(
      new ShallowMemoizeSubject({
        path,
        getState: () => {
          const { message, parentId } = resolve() ?? {};
          // Message not found: return SKIP_UPDATE (presumably keeps the
          // subject's previous state — NOTE(review): confirm SKIP_UPDATE semantics).
          if (!message || parentId === undefined) return SKIP_UPDATE;

          // Read the core once and derive every field from that same object.
          const thread = this._threadBinding.getState();
          const { messages, speech: speechState } = thread;
          const branches = thread.getBranches(message.id);
          const submittedFeedback = thread.getSubmittedFeedback(message.id);

          return {
            ...message,
            message,
            isLast: messages.at(-1)?.id === message.id,
            parentId,
            branches,
            branchNumber: branches.indexOf(message.id) + 1,
            branchCount: branches.length,
            speech:
              speechState?.messageId === message.id ? speechState : undefined,
            submittedFeedback,
          } satisfies MessageState;
        },
        subscribe: (onChange) => this._threadBinding.subscribe(onChange),
      }),
      this._threadBinding,
    );
  }

  /** One nested subscription subject per event type, created lazily. */
  private _eventListenerNestedSubscriptions = new Map<
    string,
    NestedSubscriptionSubject<Subscribable, ThreadRuntimePath>
  >();

  /**
   * Subscribes to a thread runtime event. Listeners are attached through a
   * `NestedSubscriptionSubject` that re-resolves the core on
   * `outerSubscribe` notifications, so they follow the current core.
   */
  public unstable_on(
    event: ThreadRuntimeEventType,
    callback: () => void,
  ): Unsubscribe {
    let subject = this._eventListenerNestedSubscriptions.get(event);
    if (!subject) {
      subject = new NestedSubscriptionSubject({
        path: this.path,
        getState: () => ({
          subscribe: (listener) =>
            this._threadBinding.getState().unstable_on(event, listener),
        }),
        subscribe: (onOuterChange) =>
          this._threadBinding.outerSubscribe(onOuterChange),
      });
      this._eventListenerNestedSubscriptions.set(event, subject);
    }
    return subject.subscribe(callback);
  }
}