@ai-sdk/react
Version:
[React](https://react.dev/) UI components for the [AI SDK](https://ai-sdk.dev/docs).
131 lines (106 loc) • 3.49 kB
text/typescript
import { AbstractChat, ChatInit, ChatState, ChatStatus, UIMessage } from 'ai';
import { throttle } from './throttle';
/**
 * `ChatState` implementation that backs the React `Chat` class.
 *
 * After construction, every mutation installs a fresh value (a new array for
 * messages) rather than editing the current one in place. It then synchronously
 * invokes the listeners registered for the affected slice: messages, status or
 * error.
 */
class ReactChatState<UI_MESSAGE extends UIMessage>
  implements ChatState<UI_MESSAGE>
{
  #messages: UI_MESSAGE[];
  #status: ChatStatus = 'ready';
  #error: Error | undefined = undefined;

  // One listener set per observable slice of state.
  #messagesCallbacks = new Set<() => void>();
  #statusCallbacks = new Set<() => void>();
  #errorCallbacks = new Set<() => void>();

  /**
   * @param initialMessages - Starting message list (defaults to empty). It is
   *   stored as given, not copied.
   */
  constructor(initialMessages: UI_MESSAGE[] = []) {
    this.#messages = initialMessages;
  }

  get status(): ChatStatus {
    return this.#status;
  }

  /** Replaces the status and notifies status listeners. */
  set status(nextStatus: ChatStatus) {
    this.#status = nextStatus;
    this.#notify(this.#statusCallbacks);
  }

  get error(): Error | undefined {
    return this.#error;
  }

  /** Replaces the error (or clears it with `undefined`) and notifies error listeners. */
  set error(nextError: Error | undefined) {
    this.#error = nextError;
    this.#notify(this.#errorCallbacks);
  }

  get messages(): UI_MESSAGE[] {
    return this.#messages;
  }

  /**
   * Stores a shallow copy of `nextMessages`, so later changes to the caller's
   * array are not seen here, then notifies message listeners.
   */
  set messages(nextMessages: UI_MESSAGE[]) {
    this.#messages = Array.from(nextMessages);
    this.#notify(this.#messagesCallbacks);
  }

  /** Appends `message` by building a new array, then notifies message listeners. */
  pushMessage = (message: UI_MESSAGE) => {
    this.#messages = this.#messages.concat(message);
    this.#notify(this.#messagesCallbacks);
  };

  /** Drops the last message (a no-op on an empty list), then notifies message listeners. */
  popMessage = () => {
    this.#messages = this.#messages.slice(0, -1);
    this.#notify(this.#messagesCallbacks);
  };

  /**
   * Builds a new array in which position `index` holds a deep copy of
   * `message`, then notifies message listeners.
   */
  replaceMessage = (index: number, message: UI_MESSAGE) => {
    const before = this.#messages.slice(0, index);
    // Deep-clone so the React Compiler (RC at time of writing) sees changes
    // to deeply nested parts/metadata as a new object identity.
    const cloned = this.snapshot(message);
    const after = this.#messages.slice(index + 1);
    this.#messages = [...before, cloned, ...after];
    this.#notify(this.#messagesCallbacks);
  };

  /** Returns a deep copy of `value` produced by `structuredClone`. */
  snapshot = <T>(value: T): T => structuredClone(value);

  /**
   * Subscribes to message changes.
   *
   * @param onChange - Invoked after each message mutation.
   * @param throttleWaitMs - When truthy, `onChange` is wrapped with
   *   `throttle(onChange, throttleWaitMs)`. `undefined` or `0` registers it
   *   unthrottled.
   * @returns A function that removes this subscription.
   */
  '~registerMessagesCallback' = (
    onChange: () => void,
    throttleWaitMs?: number,
  ): (() => void) => {
    let listener = onChange;
    if (throttleWaitMs) {
      listener = throttle(onChange, throttleWaitMs);
    }
    return this.#subscribe(this.#messagesCallbacks, listener);
  };

  /** Subscribes to status changes; returns the matching unsubscribe function. */
  '~registerStatusCallback' = (onChange: () => void): (() => void) => {
    return this.#subscribe(this.#statusCallbacks, onChange);
  };

  /** Subscribes to error changes; returns the matching unsubscribe function. */
  '~registerErrorCallback' = (onChange: () => void): (() => void) => {
    return this.#subscribe(this.#errorCallbacks, onChange);
  };

  /** Adds `listener` to `listeners` and returns a closure that removes it again. */
  #subscribe = (
    listeners: Set<() => void>,
    listener: () => void,
  ): (() => void) => {
    listeners.add(listener);
    return () => {
      listeners.delete(listener);
    };
  };

  /** Calls every listener in the set with no arguments, in insertion order. */
  #notify = (listeners: Set<() => void>): void => {
    for (const listener of listeners) {
      listener();
    }
  };
}
/**
 * React chat client: an `AbstractChat` whose state is a `ReactChatState`.
 *
 * The same state instance is handed to `AbstractChat` and also kept
 * privately, so this class can expose that state's change-subscription hooks.
 */
export class Chat<
  UI_MESSAGE extends UIMessage,
> extends AbstractChat<UI_MESSAGE> {
  readonly #state: ReactChatState<UI_MESSAGE>;

  /**
   * @param init - Chat options. `messages` seeds the `ReactChatState`. All
   *   other options go to `AbstractChat` unchanged, together with that state.
   */
  constructor(init: ChatInit<UI_MESSAGE>) {
    const { messages: initialMessages, ...rest } = init;
    const reactState = new ReactChatState(initialMessages);
    super({ ...rest, state: reactState });
    this.#state = reactState;
  }

  /** Forwards to the state's messages subscription (optionally throttled). */
  '~registerMessagesCallback' = (
    onChange: () => void,
    throttleWaitMs?: number,
  ): (() => void) => {
    return this.#state['~registerMessagesCallback'](onChange, throttleWaitMs);
  };

  /** Forwards to the state's status subscription. */
  '~registerStatusCallback' = (onChange: () => void): (() => void) => {
    return this.#state['~registerStatusCallback'](onChange);
  };

  /** Forwards to the state's error subscription. */
  '~registerErrorCallback' = (onChange: () => void): (() => void) => {
    return this.#state['~registerErrorCallback'](onChange);
  };
}