UNPKG

@ai-sdk/react

Version:

[React](https://react.dev/) UI components for the [AI SDK](https://ai-sdk.dev/docs):

466 lines (459 loc) 14 kB
// Helpers that emulate private class fields. Each "private member" is a
// WeakMap (or WeakSet) keyed by the owning instance.

/** Throws a TypeError unless `member` already has an entry for `obj`. */
var __accessCheck = (obj, member, msg) => {
  if (!member.has(obj))
    throw TypeError("Cannot " + msg);
};

/** Reads the private value stored for `obj`, or calls `getter` with `obj` as `this` when one is given. */
var __privateGet = (obj, member, getter) => {
  __accessCheck(obj, member, "read from private field");
  return getter ? getter.call(obj) : member.get(obj);
};

/** Registers `obj` in `member` with an initial `value`; throws if it is already registered. */
var __privateAdd = (obj, member, value) => {
  if (member.has(obj))
    throw TypeError("Cannot add the same private member more than once");
  member instanceof WeakSet ? member.add(obj) : member.set(obj, value);
};

/** Writes the private value for `obj` (or calls `setter`) and returns `value`. */
var __privateSet = (obj, member, value, setter) => {
  __accessCheck(obj, member, "write to private field");
  setter ? setter.call(obj, value) : member.set(obj, value);
  return value;
};

// src/use-chat.ts
import { useCallback, useEffect, useRef, useSyncExternalStore } from "react";

// src/chat.react.ts
import { AbstractChat } from "ai";

// src/throttle.ts
import throttleFunction from "throttleit";

/**
 * Wraps `fn` with `throttleFunction` when a wait time is given.
 *
 * @param fn - Function to wrap.
 * @param waitMs - Wait time passed to `throttleFunction`; when `null` or
 *   `undefined`, `fn` is returned unchanged.
 * @returns The wrapped function, or `fn` itself.
 */
function throttle(fn, waitMs) {
  return waitMs != null ? throttleFunction(fn, waitMs) : fn;
}

// src/chat.react.ts
var _messages, _status, _error, _messagesCallbacks, _statusCallbacks, _errorCallbacks, _callMessagesCallbacks, _callStatusCallbacks, _callErrorCallbacks;

/**
 * Observable chat state holding `messages`, `status` and `error`.
 *
 * Every write replaces the stored value and then synchronously invokes the
 * callbacks registered for that slice. `messages` is always replaced by a new
 * array (never mutated in place).
 */
var ReactChatState = class {
  /**
   * @param initialMessages - Initial message list; stored as-is (not copied).
   */
  constructor(initialMessages = []) {
    __privateAdd(this, _messages, void 0);
    __privateAdd(this, _status, "ready");
    __privateAdd(this, _error, void 0);
    __privateAdd(this, _messagesCallbacks, /* @__PURE__ */ new Set());
    __privateAdd(this, _statusCallbacks, /* @__PURE__ */ new Set());
    __privateAdd(this, _errorCallbacks, /* @__PURE__ */ new Set());
    // Appends `message` by building a new array, then notifies subscribers.
    this.pushMessage = (message) => {
      __privateSet(this, _messages, __privateGet(this, _messages).concat(message));
      __privateGet(this, _callMessagesCallbacks).call(this);
    };
    // Drops the last message by building a new array, then notifies subscribers.
    this.popMessage = () => {
      __privateSet(this, _messages, __privateGet(this, _messages).slice(0, -1));
      __privateGet(this, _callMessagesCallbacks).call(this);
    };
    // Replaces the message at `index` with a deep clone of `message`.
    this.replaceMessage = (index, message) => {
      __privateSet(this, _messages, [
        ...__privateGet(this, _messages).slice(0, index),
        // We deep clone the message here to ensure the new React Compiler
        // (currently in RC) detects deeply nested parts/metadata changes:
        this.snapshot(message),
        ...__privateGet(this, _messages).slice(index + 1)
      ]);
      __privateGet(this, _callMessagesCallbacks).call(this);
    };
    // Deep copy via the global structuredClone.
    this.snapshot = (value) => structuredClone(value);
    // Subscribes to message changes. A truthy `throttleWaitMs` throttles the
    // callback. Returns an unsubscribe function.
    this["~registerMessagesCallback"] = (onChange, throttleWaitMs) => {
      const callback = throttleWaitMs ? throttle(onChange, throttleWaitMs) : onChange;
      __privateGet(this, _messagesCallbacks).add(callback);
      return () => {
        __privateGet(this, _messagesCallbacks).delete(callback);
      };
    };
    // Subscribes to status changes. Returns an unsubscribe function.
    this["~registerStatusCallback"] = (onChange) => {
      __privateGet(this, _statusCallbacks).add(onChange);
      return () => {
        __privateGet(this, _statusCallbacks).delete(onChange);
      };
    };
    // Subscribes to error changes. Returns an unsubscribe function.
    this["~registerErrorCallback"] = (onChange) => {
      __privateGet(this, _errorCallbacks).add(onChange);
      return () => {
        __privateGet(this, _errorCallbacks).delete(onChange);
      };
    };
    // Private notifiers: invoke every registered callback for a slice.
    __privateAdd(this, _callMessagesCallbacks, () => {
      __privateGet(this, _messagesCallbacks).forEach((callback) => callback());
    });
    __privateAdd(this, _callStatusCallbacks, () => {
      __privateGet(this, _statusCallbacks).forEach((callback) => callback());
    });
    __privateAdd(this, _callErrorCallbacks, () => {
      __privateGet(this, _errorCallbacks).forEach((callback) => callback());
    });
    // Set last, without notifying: no subscriber can exist yet.
    __privateSet(this, _messages, initialMessages);
  }
  /** Current status; initialised to "ready". */
  get status() {
    return __privateGet(this, _status);
  }
  /** Stores the new status and notifies status subscribers. */
  set status(newStatus) {
    __privateSet(this, _status, newStatus);
    __privateGet(this, _callStatusCallbacks).call(this);
  }
  /** Current error; initially `undefined`. */
  get error() {
    return __privateGet(this, _error);
  }
  /** Stores the new error and notifies error subscribers. */
  set error(newError) {
    __privateSet(this, _error, newError);
    __privateGet(this, _callErrorCallbacks).call(this);
  }
  /** Current message array. */
  get messages() {
    return __privateGet(this, _messages);
  }
  /** Stores a shallow copy of `newMessages` and notifies message subscribers. */
  set messages(newMessages) {
    __privateSet(this, _messages, [...newMessages]);
    __privateGet(this, _callMessagesCallbacks).call(this);
  }
};
_messages = new WeakMap();
_status = new WeakMap();
_error = new WeakMap();
_messagesCallbacks = new WeakMap();
_statusCallbacks = new WeakMap();
_errorCallbacks = new WeakMap();
_callMessagesCallbacks = new WeakMap();
_callStatusCallbacks = new WeakMap();
_callErrorCallbacks = new WeakMap();

var _state;

/**
 * `AbstractChat` subclass backed by a `ReactChatState`.
 *
 * Exposes the state's `~register*Callback` subscription methods so that
 * `useChat` can subscribe to messages, status and error.
 */
var Chat = class extends AbstractChat {
  /**
   * @param init - Chat options. `messages` seeds the `ReactChatState`; all
   *   remaining options are forwarded to `AbstractChat` together with `state`.
   */
  constructor({ messages, ...init }) {
    const state = new ReactChatState(messages);
    super({ ...init, state });
    __privateAdd(this, _state, void 0);
    this["~registerMessagesCallback"] = (onChange, throttleWaitMs) => __privateGet(this, _state)["~registerMessagesCallback"](onChange, throttleWaitMs);
    this["~registerStatusCallback"] = (onChange) => __privateGet(this, _state)["~registerStatusCallback"](onChange);
    this["~registerErrorCallback"] = (onChange) => __privateGet(this, _state)["~registerErrorCallback"](onChange);
    __privateSet(this, _state, state);
  }
};
_state = new WeakMap();

// src/use-chat.ts

/**
 * React hook that binds a `Chat` instance to component state.
 *
 * @param options - Optional settings:
 *   - `experimental_throttle`: wait time used to throttle message updates.
 *   - `resume` (default `false`): when true, `resumeStream()` is called from an effect.
 *   - `chat`: an existing chat instance to use instead of creating one.
 *   - any other keys are passed to `new Chat(...)` when no `chat` is given.
 * @returns `{ id, messages, setMessages, sendMessage, regenerate, clearError,
 *   stop, error, resumeStream, status, addToolResult, addToolOutput }`.
 */
function useChat({
  experimental_throttle: throttleWaitMs,
  resume = false,
  ...options
} = {}) {
  const chatRef = useRef(
    "chat" in options ? options.chat : new Chat(options)
  );
  // Swap the instance when the caller hands in a different chat object or a
  // different id than the one currently held.
  const shouldRecreateChat = ("chat" in options && options.chat !== chatRef.current) || ("id" in options && chatRef.current.id !== options.id);
  if (shouldRecreateChat) {
    chatRef.current = "chat" in options ? options.chat : new Chat(options);
  }
  const optionsId = "id" in options ? options.id : null;
  const subscribeToMessages = useCallback(
    (update) => chatRef.current["~registerMessagesCallback"](update, throttleWaitMs),
    // optionsId is required to trigger re-subscription when the chat ID changes
    // eslint-disable-next-line react-hooks/exhaustive-deps
    [throttleWaitMs, optionsId]
  );
  const messages = useSyncExternalStore(
    subscribeToMessages,
    () => chatRef.current.messages,
    () => chatRef.current.messages
  );
  const status = useSyncExternalStore(
    chatRef.current["~registerStatusCallback"],
    () => chatRef.current.status,
    () => chatRef.current.status
  );
  const error = useSyncExternalStore(
    chatRef.current["~registerErrorCallback"],
    () => chatRef.current.error,
    () => chatRef.current.error
  );
  // Accepts either a new message array or an updater that receives the
  // current messages.
  const setMessages = useCallback(
    (messagesParam) => {
      if (typeof messagesParam === "function") {
        messagesParam = messagesParam(chatRef.current.messages);
      }
      chatRef.current.messages = messagesParam;
    },
    [chatRef]
  );
  useEffect(() => {
    if (resume) {
      chatRef.current.resumeStream();
    }
  }, [resume, chatRef]);
  return {
    id: chatRef.current.id,
    messages,
    setMessages,
    sendMessage: chatRef.current.sendMessage,
    regenerate: chatRef.current.regenerate,
    clearError: chatRef.current.clearError,
    stop: chatRef.current.stop,
    error,
    resumeStream: chatRef.current.resumeStream,
    status,
    /**
     * @deprecated Use `addToolOutput` instead.
     */
    addToolResult: chatRef.current.addToolOutput,
    addToolOutput: chatRef.current.addToolOutput
  };
}

// src/use-completion.ts
import { callCompletionApi } from "ai";
import { useCallback as useCallback2, useEffect as useEffect2, useId, useRef as useRef2, useState } from "react";
import useSWR from "swr";

/**
 * React hook for text completions driven by `callCompletionApi`.
 *
 * The completion text and the loading flag are kept in SWR entries keyed by
 * `[api, completionId]` and `[completionId, "loading"]` (null fetcher; only
 * written through `mutate`). `completionId` is `id` when truthy, otherwise a
 * `useId()` value.
 *
 * @param options - Optional settings: `api` (default "/api/completion"), `id`,
 *   `initialCompletion` (default ""), `initialInput` (default ""),
 *   `credentials`, `headers`, `body`, `streamProtocol` (default "data"),
 *   `fetch`, `onFinish`, `onError`, `experimental_throttle`.
 * @returns `{ completion, complete, error, setCompletion, stop, input,
 *   setInput, handleInputChange, handleSubmit, isLoading }`.
 */
function useCompletion({
  api = "/api/completion",
  id,
  initialCompletion = "",
  initialInput = "",
  credentials,
  headers,
  body,
  streamProtocol = "data",
  fetch: fetch2,
  onFinish,
  onError,
  experimental_throttle: throttleWaitMs
} = {}) {
  const hookId = useId();
  const completionId = id || hookId;
  const { data, mutate } = useSWR([api, completionId], null, {
    fallbackData: initialCompletion
  });
  const { data: isLoading = false, mutate: mutateLoading } = useSWR(
    [completionId, "loading"],
    null
  );
  const [error, setError] = useState(void 0);
  const completion = data;
  const [abortController, setAbortController] = useState(null);
  // Latest request metadata is kept in a ref so `triggerRequest` reads
  // current values without listing them as dependencies.
  const extraMetadataRef = useRef2({
    credentials,
    headers,
    body
  });
  useEffect2(() => {
    extraMetadataRef.current = {
      credentials,
      headers,
      body
    };
  }, [credentials, headers, body]);
  const triggerRequest = useCallback2(
    async (prompt, options) => callCompletionApi({
      api,
      prompt,
      credentials: extraMetadataRef.current.credentials,
      // Per-call headers/body are merged over the hook-level ones.
      headers: {
        ...extraMetadataRef.current.headers,
        ...(options == null ? void 0 : options.headers)
      },
      body: {
        ...extraMetadataRef.current.body,
        ...(options == null ? void 0 : options.body)
      },
      streamProtocol,
      fetch: fetch2,
      // throttle streamed ui updates:
      setCompletion: throttle(
        (completion2) => mutate(completion2, false),
        throttleWaitMs
      ),
      setLoading: mutateLoading,
      setError,
      setAbortController,
      onFinish,
      onError
    }),
    [
      mutate,
      mutateLoading,
      api,
      extraMetadataRef,
      setAbortController,
      onFinish,
      onError,
      setError,
      streamProtocol,
      fetch2,
      throttleWaitMs
    ]
  );
  // Aborts the controller currently held in state (if any) and clears it.
  const stop = useCallback2(() => {
    if (abortController) {
      abortController.abort();
      setAbortController(null);
    }
  }, [abortController]);
  const setCompletion = useCallback2(
    (completion2) => {
      mutate(completion2, false);
    },
    [mutate]
  );
  const complete = useCallback2(
    async (prompt, options) => {
      return triggerRequest(prompt, options);
    },
    [triggerRequest]
  );
  const [input, setInput] = useState(initialInput);
  // Calls `event.preventDefault()` when present, then submits the current
  // input; returns `undefined` without sending when the input is empty.
  const handleSubmit = useCallback2(
    (event) => {
      var _a;
      (_a = event == null ? void 0 : event.preventDefault) == null ? void 0 : _a.call(event);
      return input ? complete(input) : void 0;
    },
    [input, complete]
  );
  const handleInputChange = useCallback2(
    (e) => {
      setInput(e.target.value);
    },
    [setInput]
  );
  return {
    completion,
    complete,
    error,
    setCompletion,
    stop,
    input,
    setInput,
    handleInputChange,
    handleSubmit,
    isLoading
  };
}

// src/use-object.ts
import { isAbortError, safeValidateTypes } from "@ai-sdk/provider-utils";
import { asSchema, isDeepEqualData, parsePartialJson } from "ai";
import { useCallback as useCallback3, useId as useId2, useRef as useRef3, useState as useState2 } from "react";
import useSWR2 from "swr";

// Resolved lazily at call time rather than captured at module load.
var getOriginalFetch = () => fetch;

/**
 * React hook that POSTs an input to `api` and exposes the streamed response
 * as a progressively parsed object.
 *
 * The response body is decoded as text, accumulated, and re-parsed with
 * `parsePartialJson` on every chunk; the object is published through SWR
 * (key `[api, completionId]`, null fetcher) whenever it changes according to
 * `isDeepEqualData`.
 *
 * @param options - `api`, `id`, `schema`, `initialValue`, `fetch`, `onError`,
 *   `onFinish`, `headers`, `credentials`. `schema` is used to validate the
 *   final object before `onFinish` is called.
 * @returns `{ submit, object, error, isLoading, stop, clear }`.
 */
function useObject({
  api,
  id,
  schema,
  // required, in the future we will use it for validation
  initialValue,
  fetch: fetch2,
  onError,
  onFinish,
  headers,
  credentials
}) {
  const hookId = useId2();
  const completionId = id != null ? id : hookId;
  const { data, mutate } = useSWR2(
    [api, completionId],
    null,
    { fallbackData: initialValue }
  );
  const [error, setError] = useState2(void 0);
  const [isLoading, setIsLoading] = useState2(false);
  const abortControllerRef = useRef3(null);
  // Aborts the in-flight request (if any) and always resets loading state.
  const stop = useCallback3(() => {
    var _a;
    try {
      (_a = abortControllerRef.current) == null ? void 0 : _a.abort();
    } catch (ignored) {
      // Deliberate best-effort: a failing abort must not prevent the reset below.
    } finally {
      setIsLoading(false);
      abortControllerRef.current = null;
    }
  }, []);
  // Sends `input` as a JSON POST body and streams the response into `object`.
  const submit = async (input) => {
    try {
      clearObject();
      setIsLoading(true);
      const abortController = new AbortController();
      abortControllerRef.current = abortController;
      const actualFetch = fetch2 != null ? fetch2 : getOriginalFetch();
      const response = await actualFetch(api, {
        method: "POST",
        headers: {
          "Content-Type": "application/json",
          ...headers
        },
        credentials,
        signal: abortController.signal,
        body: JSON.stringify(input)
      });
      if (!response.ok) {
        // Response.text() always resolves to a string, so a null check can
        // never select the fallback; use a falsy check so that an empty
        // error body still produces a meaningful message.
        const errorText = await response.text();
        throw new Error(errorText || "Failed to fetch the response.");
      }
      if (response.body == null) {
        throw new Error("The response body is empty.");
      }
      let accumulatedText = "";
      let latestObject = void 0;
      await response.body.pipeThrough(new TextDecoderStream()).pipeTo(
        new WritableStream({
          async write(chunk) {
            accumulatedText += chunk;
            const { value } = await parsePartialJson(accumulatedText);
            const currentObject = value;
            // Only publish when the parsed object actually changed.
            if (!isDeepEqualData(latestObject, currentObject)) {
              latestObject = currentObject;
              mutate(currentObject);
            }
          },
          async close() {
            setIsLoading(false);
            abortControllerRef.current = null;
            if (onFinish != null) {
              const validationResult = await safeValidateTypes({
                value: latestObject,
                schema: asSchema(schema)
              });
              onFinish(
                validationResult.success ? { object: validationResult.value, error: void 0 } : { object: void 0, error: validationResult.error }
              );
            }
          }
        })
      );
    } catch (error2) {
      // Aborts are not surfaced as errors.
      if (isAbortError(error2)) {
        return;
      }
      if (onError && error2 instanceof Error) {
        onError(error2);
      }
      setIsLoading(false);
      setError(error2 instanceof Error ? error2 : new Error(String(error2)));
    }
  };
  // Stops any in-flight request and resets object, error and loading state.
  const clear = () => {
    stop();
    clearObject();
  };
  const clearObject = () => {
    setError(void 0);
    setIsLoading(false);
    mutate(void 0);
  };
  return {
    submit,
    object: data,
    error,
    isLoading,
    stop,
    clear
  };
}
var experimental_useObject = useObject;
export {
  Chat,
  experimental_useObject,
  useChat,
  useCompletion
};
//# sourceMappingURL=index.mjs.map