UNPKG

@ai-sdk/solid

Version:

> **Warning** `@ai-sdk/solid` has been deprecated and will be removed in AI SDK 5

848 lines (842 loc) 24.7 kB
// src/use-chat.ts
import {
  callChatApi,
  extractMaxToolInvocationStep,
  fillMessageParts,
  generateId as generateIdFunc,
  getMessageParts,
  isAssistantMessageWithCompletedToolCalls,
  prepareAttachmentsForRequest,
  shouldResubmitMessages,
  updateToolCallResult
} from "@ai-sdk/ui-utils";
import { createEffect, createMemo as createMemo2, createSignal } from "solid-js";
import { createStore, reconcile } from "solid-js/store";

// src/utils/convert-to-accessor-options.ts
import { createMemo } from "solid-js";

/**
 * Resolve hook options and wrap every property in a memoized accessor.
 *
 * @param options - a plain options object, or a function returning one
 *   (the function is called once per invocation of this helper).
 * @returns an object with the same keys whose values are `createMemo`
 *   accessors returning the original property values.
 */
function convertToAccessorOptions(options) {
  const resolvedOptions = typeof options === "function" ? options() : options;
  return Object.entries(resolvedOptions).reduce(
    (reactiveOptions, [key, value]) => {
      reactiveOptions[key] = createMemo(() => value);
      return reactiveOptions;
    },
    {}
  );
}

// src/utils/reactive-lru.ts
import { batch } from "solid-js";
import { TriggerCache } from "@solid-primitives/trigger";

/** Sentinel trigger key used to track "the set of keys / size changed". */
var $KEYS = Symbol("track-keys");

/**
 * A `Map` whose reads are tracked by Solid and whose size is bounded.
 *
 * Reads (`has`, `get`, `size`, `keys`, `values`, `entries`, `forEach`) subscribe
 * the current reactive scope through `TriggerCache.track`; writes (`set`,
 * `delete`, `clear`) notify subscribers through `TriggerCache.dirty`.
 *
 * Access order is recorded by `has`/`get` (hits only), `set` and `forEach`.
 * When more than `maxSize` keys have been recorded, the least-recently
 * accessed key is deleted.
 */
var ReactiveLRU = class extends Map {
  #keyTriggers = new TriggerCache();
  #valueTriggers = new TriggerCache();
  #maxSize;
  /** Keys ordered from least- to most-recently accessed. */
  #accessList = [];

  /**
   * @param maxSize - maximum number of entries kept before eviction (default 10).
   * @param initial - optional iterable of `[key, value]` pairs inserted via `set`.
   */
  constructor(maxSize = 10, initial) {
    super();
    this.#maxSize = maxSize;
    if (initial) {
      for (const [key, value] of initial) {
        this.set(key, value);
      }
    }
  }

  /** Mark `key` as most recently used, evicting the LRU key if over capacity. */
  #recordAccess(key) {
    const index = this.#accessList.indexOf(key);
    if (index > -1) {
      this.#accessList.splice(index, 1);
    }
    this.#accessList.push(key);
    if (this.#accessList.length > this.#maxSize) {
      const lru = this.#accessList.shift();
      this.delete(lru);
    }
  }

  // reads
  has(key) {
    this.#keyTriggers.track(key);
    const exists = super.has(key);
    if (exists) {
      this.#recordAccess(key);
    }
    return exists;
  }

  get(key) {
    this.#valueTriggers.track(key);
    const value = super.get(key);
    // Only hits (non-undefined values) count as an access.
    if (value !== void 0) {
      this.#recordAccess(key);
    }
    return value;
  }

  get size() {
    this.#keyTriggers.track($KEYS);
    return super.size;
  }

  *keys() {
    for (const key of super.keys()) {
      this.#keyTriggers.track(key);
      yield key;
    }
    this.#keyTriggers.track($KEYS);
  }

  *values() {
    for (const [key, v] of super.entries()) {
      this.#valueTriggers.track(key);
      yield v;
    }
    this.#keyTriggers.track($KEYS);
  }

  *entries() {
    for (const entry of super.entries()) {
      this.#valueTriggers.track(entry[0]);
      yield entry;
    }
    this.#keyTriggers.track($KEYS);
  }

  // writes
  set(key, value) {
    batch(() => {
      if (super.has(key)) {
        // Same reference: nothing to notify, just refresh recency.
        if (super.get(key) === value) {
          this.#recordAccess(key);
          return;
        }
      } else {
        this.#keyTriggers.dirty(key);
        this.#keyTriggers.dirty($KEYS);
      }
      this.#valueTriggers.dirty(key);
      super.set(key, value);
      this.#recordAccess(key);
    });
    return this;
  }

  delete(key) {
    const r = super.delete(key);
    if (r) {
      batch(() => {
        this.#keyTriggers.dirty(key);
        this.#keyTriggers.dirty($KEYS);
        this.#valueTriggers.dirty(key);
        const index = this.#accessList.indexOf(key);
        if (index > -1) {
          this.#accessList.splice(index, 1);
        }
      });
    }
    return r;
  }

  clear() {
    if (super.size) {
      batch(() => {
        for (const v of super.keys()) {
          this.#keyTriggers.dirty(v);
          this.#valueTriggers.dirty(v);
        }
        super.clear();
        this.#accessList = [];
        this.#keyTriggers.dirty($KEYS);
      });
    }
  }

  // callback
  /**
   * Same contract as `Map.prototype.forEach`, including the optional
   * `thisArg`, while tracking keys/values and recording access.
   */
  forEach(callbackfn, thisArg) {
    this.#keyTriggers.track($KEYS);
    for (const [key, v] of super.entries()) {
      this.#valueTriggers.track(key);
      this.#recordAccess(key);
      callbackfn.call(thisArg, v, key, this);
    }
  }

  [Symbol.iterator]() {
    return this.entries();
  }
};

// src/use-chat.ts
/** Shared message cache, keyed by `${api}|${id}|messages` (at most 10 chats). */
var chatCache = new ReactiveLRU();

/**
 * Solid hook for a chat UI backed by a streaming chat endpoint.
 *
 * Messages live in the module-level `chatCache`, so hook instances with the
 * same `api` and `id` share their message list.
 *
 * @param rawUseChatOptions - options object, or a function returning one.
 * @returns accessors and actions: `messages`, `id`, `append`, `error`,
 *   `reload`, `stop`, `setMessages`, `input`, `setInput`,
 *   `handleInputChange`, `handleSubmit`, `isLoading`, `status`, `data`,
 *   `setData`, `addToolResult`.
 */
function useChat(rawUseChatOptions = {}) {
  const resolvedOptions = createMemo2(
    () => convertToAccessorOptions(rawUseChatOptions)
  );
  const prepareFn = createMemo2(() => {
    const opts = resolvedOptions();
    return opts.experimental_prepareRequestBody?.();
  });
  const useChatOptions = createMemo2(() => ({
    ...resolvedOptions(),
    experimental_prepareRequestBody: prepareFn
  }));
  const api = createMemo2(() => useChatOptions().api?.() ?? "/api/chat");
  const generateId2 = createMemo2(
    () => useChatOptions().generateId?.() ?? generateIdFunc
  );
  const chatId = createMemo2(() => useChatOptions().id?.() ?? generateId2()());
  const chatKey = createMemo2(() => `${api()}|${chatId()}|messages`);
  const _messages = createMemo2(
    () => chatCache.get(chatKey()) ?? useChatOptions().initialMessages?.() ?? []
  );

  // Store exposed to consumers; reconciled from the cache on every change.
  const [messagesStore, setMessagesStore] = createStore(
    fillMessageParts(_messages())
  );
  createEffect(() => {
    setMessagesStore(reconcile(fillMessageParts(_messages()), { merge: true }));
  });

  /** Write the message list for the current chat key into the shared cache. */
  const mutate = (messages) => {
    chatCache.set(chatKey(), messages);
  };

  const [error, setError] = createSignal(void 0);
  const [streamData, setStreamData] = createSignal(void 0);
  const [status, setStatus] = createSignal("ready");

  // Plain (non-reactive) snapshot of the latest messages, read by the actions below.
  let messagesRef = fillMessageParts(_messages()) || [];
  createEffect(() => {
    messagesRef = fillMessageParts(_messages()) || [];
  });

  let abortController = null;

  let extraMetadata = {
    credentials: useChatOptions().credentials?.(),
    headers: useChatOptions().headers?.(),
    body: useChatOptions().body?.()
  };
  createEffect(() => {
    extraMetadata = {
      credentials: useChatOptions().credentials?.(),
      headers: useChatOptions().headers?.(),
      body: useChatOptions().body?.()
    };
  });

  /**
   * Send `chatRequest.messages` to the API and stream the reply into the cache.
   * Afterwards, re-submits when `shouldResubmitMessages` says tool results
   * still need an assistant answer (bounded by the `maxSteps` option).
   */
  const triggerRequest = async (chatRequest) => {
    setError(void 0);
    setStatus("submitted");
    const messageCount = messagesRef.length;
    const maxStep = extractMaxToolInvocationStep(
      chatRequest.messages[chatRequest.messages.length - 1]?.toolInvocations
    );
    try {
      abortController = new AbortController();
      const streamProtocol = useChatOptions().streamProtocol?.() ?? "data";
      const onFinish = useChatOptions().onFinish?.();
      const onResponse = useChatOptions().onResponse?.();
      const onToolCall = useChatOptions().onToolCall?.();
      const sendExtraMessageFields = useChatOptions().sendExtraMessageFields?.();
      const keepLastMessageOnError =
        useChatOptions().keepLastMessageOnError?.() ?? true;
      const experimental_prepareRequestBody =
        useChatOptions().experimental_prepareRequestBody?.();
      const previousMessages = messagesRef;
      const chatMessages = fillMessageParts(chatRequest.messages);
      // Optimistically show the outgoing messages.
      mutate(chatMessages);
      const existingStreamData = streamData() ?? [];
      // Unless extra fields are requested, send only the known message fields.
      const constructedMessagesPayload = sendExtraMessageFields
        ? chatMessages
        : chatMessages.map(
            ({
              role,
              content,
              experimental_attachments,
              data,
              annotations,
              toolInvocations,
              parts
            }) => ({
              role,
              content,
              ...(experimental_attachments !== void 0 && {
                experimental_attachments
              }),
              ...(data !== void 0 && { data }),
              ...(annotations !== void 0 && { annotations }),
              ...(toolInvocations !== void 0 && { toolInvocations }),
              ...(parts !== void 0 && { parts })
            })
          );
      await callChatApi({
        api: api(),
        body: experimental_prepareRequestBody?.({
          id: chatId(),
          messages: chatMessages,
          requestData: chatRequest.data,
          requestBody: chatRequest.body
        }) ?? {
          id: chatId(),
          messages: constructedMessagesPayload,
          data: chatRequest.data,
          ...extraMetadata.body,
          ...chatRequest.body
        },
        streamProtocol,
        credentials: extraMetadata.credentials,
        headers: {
          ...extraMetadata.headers,
          ...chatRequest.headers
        },
        abortController: () => abortController,
        restoreMessagesOnFailure() {
          if (!keepLastMessageOnError) {
            mutate(previousMessages);
          }
        },
        onResponse,
        onUpdate({ message, data, replaceLastMessage }) {
          setStatus("streaming");
          mutate([
            ...(replaceLastMessage
              ? chatMessages.slice(0, chatMessages.length - 1)
              : chatMessages),
            message
          ]);
          if (data?.length) {
            setStreamData([...existingStreamData, ...data]);
          }
        },
        onToolCall,
        onFinish,
        generateId: generateId2(),
        fetch: useChatOptions().fetch?.(),
        lastMessage: chatMessages[chatMessages.length - 1]
      });
      abortController = null;
      setStatus("ready");
    } catch (err) {
      // A user-initiated abort is not an error.
      if (err.name === "AbortError") {
        abortController = null;
        setStatus("ready");
        return null;
      }
      const onError = useChatOptions().onError?.();
      if (onError && err instanceof Error) {
        onError(err);
      }
      setError(err);
      setStatus("error");
    }
    const maxSteps = useChatOptions().maxSteps?.() ?? 1;
    const messages = messagesRef;
    if (
      shouldResubmitMessages({
        originalMaxToolInvocationStep: maxStep,
        originalMessageCount: messageCount,
        maxSteps,
        messages
      })
    ) {
      await triggerRequest({ messages });
    }
  };

  /** Append a message (assigning id/createdAt/parts when missing) and send it. */
  const append = async (
    message,
    {
      data,
      headers,
      body,
      experimental_attachments = message.experimental_attachments
    } = {}
  ) => {
    const attachmentsForRequest = await prepareAttachmentsForRequest(
      experimental_attachments
    );
    const messages = messagesRef.concat({
      ...message,
      id: message.id ?? generateId2()(),
      createdAt: message.createdAt ?? /* @__PURE__ */ new Date(),
      experimental_attachments:
        attachmentsForRequest.length > 0 ? attachmentsForRequest : void 0,
      parts: getMessageParts(message)
    });
    return triggerRequest({ messages, headers, body, data });
  };

  /** Re-send the conversation, dropping a trailing assistant message first. */
  const reload = async ({ data, headers, body } = {}) => {
    if (messagesRef.length === 0) {
      return null;
    }
    const lastMessage = messagesRef[messagesRef.length - 1];
    return triggerRequest({
      messages:
        lastMessage.role === "assistant" ? messagesRef.slice(0, -1) : messagesRef,
      headers,
      body,
      data
    });
  };

  /** Abort the in-flight request, if any. */
  const stop = () => {
    if (abortController) {
      abortController.abort();
      abortController = null;
    }
  };

  /** Replace messages with an array or an updater `(current) => next`. */
  const setMessages = (messagesArg) => {
    if (typeof messagesArg === "function") {
      messagesArg = messagesArg(messagesRef);
    }
    const messagesWithParts = fillMessageParts(messagesArg);
    mutate(messagesWithParts);
    messagesRef = messagesWithParts;
  };

  /** Replace stream data with a value or an updater `(current) => next`. */
  const setData = (dataArg) => {
    if (typeof dataArg === "function") {
      dataArg = dataArg(streamData());
    }
    setStreamData(dataArg);
  };

  const [input, setInput] = createSignal(
    useChatOptions().initialInput?.() || ""
  );

  /**
   * Send the current input as a user message, then clear the input.
   * Empty input is ignored unless `options.allowEmptySubmit` is set.
   */
  const handleSubmit = async (event, options = {}, metadata) => {
    event?.preventDefault?.();
    const inputValue = input();
    if (!inputValue && !options.allowEmptySubmit) return;
    const attachmentsForRequest = await prepareAttachmentsForRequest(
      options.experimental_attachments
    );
    if (metadata) {
      extraMetadata = { ...extraMetadata, ...metadata };
    }
    triggerRequest({
      messages: messagesRef.concat({
        id: generateId2()(),
        role: "user",
        content: inputValue,
        createdAt: /* @__PURE__ */ new Date(),
        experimental_attachments:
          attachmentsForRequest.length > 0 ? attachmentsForRequest : void 0,
        parts: [{ type: "text", text: inputValue }]
      }),
      headers: options.headers,
      body: options.body,
      data: options.data
    });
    setInput("");
  };

  const handleInputChange = (e) => {
    setInput(e.target.value);
  };

  /**
   * Record a client-side tool result; when no request is in flight and the
   * last assistant message has all tool calls completed, send a follow-up.
   */
  const addToolResult = ({ toolCallId, result }) => {
    const currentMessages = messagesRef ?? [];
    updateToolCallResult({
      messages: currentMessages,
      toolCallId,
      toolResult: result
    });
    mutate(currentMessages);
    if (status() === "submitted" || status() === "streaming") {
      return;
    }
    const lastMessage = currentMessages[currentMessages.length - 1];
    if (isAssistantMessageWithCompletedToolCalls(lastMessage)) {
      triggerRequest({ messages: currentMessages });
    }
  };

  const isLoading = createMemo2(
    () => status() === "submitted" || status() === "streaming"
  );

  return {
    // TODO next major release: replace with direct message store access (breaking change)
    messages: () => messagesStore,
    // NOTE: read once here, so `id` is a snapshot rather than an accessor.
    id: chatId(),
    append,
    error,
    reload,
    stop,
    setMessages,
    input,
    setInput,
    handleInputChange,
    handleSubmit,
    isLoading,
    status,
    data: streamData,
    setData,
    addToolResult
  };
}

// src/use-completion.ts
import { callCompletionApi } from "@ai-sdk/ui-utils";
import {
  createEffect as createEffect2,
  createMemo as createMemo3,
  createSignal as createSignal2,
  createUniqueId
} from "solid-js";

/** Shared completion cache, keyed by `${api}|${id}|completion` (at most 10). */
var completionCache = new ReactiveLRU();

/**
 * Solid hook for single-prompt text completion against a streaming endpoint.
 *
 * @param rawUseCompletionOptions - options object, or a function returning one.
 * @returns `completion`, `complete`, `error`, `stop`, `setCompletion`,
 *   `input`, `setInput`, `handleInputChange`, `handleSubmit`, `isLoading`,
 *   and `data`.
 */
function useCompletion(rawUseCompletionOptions = {}) {
  const useCompletionOptions = createMemo3(
    () => convertToAccessorOptions(rawUseCompletionOptions)
  );
  const api = createMemo3(
    () => useCompletionOptions().api?.() ?? "/api/completion"
  );
  const idKey = createMemo3(
    () => useCompletionOptions().id?.() ?? `completion-${createUniqueId()}`
  );
  const completionKey = createMemo3(() => `${api()}|${idKey()}|completion`);
  const completion = createMemo3(
    () =>
      completionCache.get(completionKey()) ??
      useCompletionOptions().initialCompletion?.() ??
      ""
  );

  /** Write the completion text for the current key into the shared cache. */
  const mutate = (data) => {
    completionCache.set(completionKey(), data);
  };

  const [error, setError] = createSignal2(void 0);
  const [streamData, setStreamData] = createSignal2(void 0);
  const [isLoading, setIsLoading] = createSignal2(false);
  const [abortController, setAbortController] = createSignal2(null);

  let extraMetadata = {
    credentials: useCompletionOptions().credentials?.(),
    headers: useCompletionOptions().headers?.(),
    body: useCompletionOptions().body?.()
  };
  createEffect2(() => {
    extraMetadata = {
      credentials: useCompletionOptions().credentials?.(),
      headers: useCompletionOptions().headers?.(),
      body: useCompletionOptions().body?.()
    };
  });

  /** Request a completion for `prompt`; per-call headers/body override defaults. */
  const complete = async (prompt, options) => {
    const existingData = streamData() ?? [];
    return callCompletionApi({
      api: api(),
      prompt,
      credentials: useCompletionOptions().credentials?.(),
      headers: { ...extraMetadata.headers, ...options?.headers },
      body: { ...extraMetadata.body, ...options?.body },
      streamProtocol: useCompletionOptions().streamProtocol?.(),
      setCompletion: mutate,
      setLoading: setIsLoading,
      setError,
      setAbortController,
      onResponse: useCompletionOptions().onResponse?.(),
      onFinish: useCompletionOptions().onFinish?.(),
      onError: useCompletionOptions().onError?.(),
      onData: (data) => {
        setStreamData([...existingData, ...(data ?? [])]);
      },
      fetch: useCompletionOptions().fetch?.()
    });
  };

  const stop = () => {
    if (abortController()) {
      abortController().abort();
    }
  };

  const setCompletion = (completion2) => {
    mutate(completion2);
  };

  const [input, setInput] = createSignal2(
    useCompletionOptions().initialInput?.() ?? ""
  );

  const handleInputChange = (event) => {
    setInput(event.target.value);
  };

  /** Complete the current input; returns `undefined` when the input is empty. */
  const handleSubmit = (event) => {
    event?.preventDefault?.();
    const inputValue = input();
    return inputValue ? complete(inputValue) : void 0;
  };

  return {
    completion,
    complete,
    error,
    stop,
    setCompletion,
    input,
    setInput,
    handleInputChange,
    handleSubmit,
    isLoading,
    data: streamData
  };
}

// src/use-object.ts
import { isAbortError, safeValidateTypes } from "@ai-sdk/provider-utils";
import { asSchema, isDeepEqualData, parsePartialJson } from "@ai-sdk/ui-utils";
import {
  createMemo as createMemo4,
  createSignal as createSignal3,
  createUniqueId as createUniqueId2
} from "solid-js";

var getOriginalFetch = () => fetch;

/** Shared object cache, keyed by the hook `id` (at most 10 entries). */
var objectCache = new ReactiveLRU();

/**
 * Solid hook that streams a JSON object from `api` and exposes partial results.
 *
 * `submit(input)` POSTs `JSON.stringify(input)`, accumulates the response
 * text, and publishes each partial JSON parse that differs from the previous
 * one. When the stream closes, `onFinish` (if any) receives the final object
 * validated against `schema`.
 *
 * @param rawUseObjectOptions - options object, or a function returning one.
 * @returns `submit`, `object`, `error`, `isLoading`, `stop`.
 */
function useObject(rawUseObjectOptions) {
  const useObjectOptions = createMemo4(
    () => convertToAccessorOptions(rawUseObjectOptions)
  );
  const idKey = createMemo4(
    () => useObjectOptions().id?.() ?? `object-${createUniqueId2()}`
  );
  const data = createMemo4(
    () => objectCache.get(idKey()) ?? useObjectOptions().initialValue?.()
  );

  const mutate = (value) => {
    objectCache.set(idKey(), value);
  };

  const [error, setError] = createSignal3();
  const [isLoading, setIsLoading] = createSignal3(false);
  let abortControllerRef = null;

  /** Abort the in-flight request (ignoring abort errors) and reset loading. */
  const stop = () => {
    try {
      abortControllerRef?.abort();
    } catch (ignored) {
    } finally {
      setIsLoading(false);
      abortControllerRef = null;
    }
  };

  const submit = async (input) => {
    try {
      mutate(void 0);
      setIsLoading(true);
      setError(void 0);
      const abortController = new AbortController();
      abortControllerRef = abortController;
      // Honor the user-supplied `fetch` option; a bare `fetch` here would
      // always resolve to the global and silently ignore the option.
      const actualFetch = useObjectOptions().fetch?.() ?? getOriginalFetch();
      const response = await actualFetch(useObjectOptions().api(), {
        method: "POST",
        headers: {
          "Content-Type": "application/json",
          ...useObjectOptions().headers?.()
        },
        credentials: useObjectOptions().credentials?.(),
        signal: abortController.signal,
        body: JSON.stringify(input)
      });
      if (!response.ok) {
        // `text()` always yields a string, so `||` (not `??`) is needed for
        // the fallback to apply to an empty error body.
        throw new Error(
          (await response.text()) || "Failed to fetch the response."
        );
      }
      if (response.body == null) {
        throw new Error("The response body is empty.");
      }
      let accumulatedText = "";
      let latestObject = void 0;
      await response.body.pipeThrough(new TextDecoderStream()).pipeTo(
        new WritableStream({
          write(chunk) {
            accumulatedText += chunk;
            const { value } = parsePartialJson(accumulatedText);
            const currentObject = value;
            // Skip cache writes (and re-renders) when the parse is unchanged.
            if (!isDeepEqualData(latestObject, currentObject)) {
              latestObject = currentObject;
              mutate(currentObject);
            }
          },
          close() {
            setIsLoading(false);
            abortControllerRef = null;
            const onFinish = useObjectOptions().onFinish?.();
            if (onFinish != null) {
              const validationResult = safeValidateTypes({
                value: latestObject,
                schema: asSchema(useObjectOptions().schema())
              });
              onFinish(
                validationResult.success
                  ? { object: validationResult.value, error: void 0 }
                  : { object: void 0, error: validationResult.error }
              );
            }
          }
        })
      );
    } catch (error2) {
      if (isAbortError(error2)) {
        return;
      }
      const onError = useObjectOptions().onError?.();
      if (onError && error2 instanceof Error) {
        onError(error2);
      }
      setIsLoading(false);
      setError(error2 instanceof Error ? error2 : new Error(String(error2)));
    }
  };

  return {
    submit,
    object: data,
    error,
    isLoading,
    stop
  };
}
var experimental_useObject = useObject;

// src/use-assistant.ts
import { isAbortError as isAbortError2 } from "@ai-sdk/provider-utils";
import { generateId, processAssistantStream } from "@ai-sdk/ui-utils";
import { createMemo as createMemo5, createSignal as createSignal4 } from "solid-js";
import { createStore as createStore2 } from "solid-js/store";

var getOriginalFetch2 = () => fetch;

/**
 * Solid hook for an assistant-thread endpoint that streams assistant parts.
 *
 * @param rawUseAssistantOptions - options object, or a function returning one.
 * @returns `append`, `messages`, `setMessages`, `threadId`, `setThreadId`,
 *   `input`, `setInput`, `handleInputChange`, `submitMessage`, `status`,
 *   `error`, `stop`.
 */
function useAssistant(rawUseAssistantOptions) {
  const useAssistantOptions = createMemo5(
    () => convertToAccessorOptions(rawUseAssistantOptions)
  );
  const [messages, setMessages] = createStore2([]);
  const [input, setInput] = createSignal4("");
  const [currentThreadId, setCurrentThreadId] = createSignal4();
  const [status, setStatus] = createSignal4("awaiting_message");
  const [error, setError] = createSignal4();

  const handleInputChange = (event) => {
    setInput(event.target.value);
  };

  let abortControllerRef = null;
  const stop = () => {
    if (abortControllerRef) {
      abortControllerRef?.abort();
      abortControllerRef = null;
    }
  };

  /** Append `message`, POST it to the API, and apply the streamed parts. */
  const append = async (message, requestOptions) => {
    setStatus("in_progress");
    setMessages((messages2) => [
      ...messages2,
      {
        ...message,
        id: message.id ?? generateId()
      }
    ]);
    setInput("");
    const abortController = new AbortController();
    try {
      abortControllerRef = abortController;
      // Honor the user-supplied `fetch` option (a bare `fetch` is the global).
      const actualFetch = useAssistantOptions().fetch?.() ?? getOriginalFetch2();
      const response = await actualFetch(useAssistantOptions().api(), {
        method: "POST",
        credentials: useAssistantOptions().credentials?.(),
        signal: abortController.signal,
        headers: {
          "Content-Type": "application/json",
          ...useAssistantOptions().headers?.()
        },
        body: JSON.stringify({
          ...useAssistantOptions().body?.(),
          // always use user-provided threadId when available:
          threadId: useAssistantOptions().threadId?.() ?? currentThreadId(),
          message: message.content,
          // optional request data:
          data: requestOptions?.data
        })
      });
      if (!response.ok) {
        // `||` so an empty error body still yields a useful message.
        throw new Error(
          (await response.text()) || "Failed to fetch the assistant response."
        );
      }
      if (response.body == null) {
        throw new Error("The response body is empty.");
      }
      await processAssistantStream({
        stream: response.body,
        onAssistantMessagePart(value) {
          setMessages((messages2) => [
            ...messages2,
            {
              id: value.id,
              role: value.role,
              content: value.content[0].text.value,
              parts: []
            }
          ]);
        },
        onTextPart(value) {
          setMessages((messages2) => {
            const lastMessage = messages2[messages2.length - 1];
            return [
              ...messages2.slice(0, messages2.length - 1),
              {
                id: lastMessage.id,
                role: lastMessage.role,
                content: lastMessage.content + value,
                parts: lastMessage.parts
              }
            ];
          });
        },
        onAssistantControlDataPart(value) {
          setCurrentThreadId(value.threadId);
          setMessages((messages2) => {
            const lastMessage = messages2[messages2.length - 1];
            // Replace the element instead of mutating it in place: the store
            // setter only notifies for elements whose reference changed.
            return [
              ...messages2.slice(0, messages2.length - 1),
              { ...lastMessage, id: value.messageId }
            ];
          });
        },
        onDataMessagePart(value) {
          setMessages((messages2) => [
            ...messages2,
            {
              id: value.id ?? generateId(),
              role: "data",
              content: "",
              data: value.data,
              parts: []
            }
          ]);
        },
        onErrorPart(value) {
          setError(new Error(value));
        }
      });
    } catch (error2) {
      // Only swallow aborts that this call's own controller triggered.
      if (isAbortError2(error2) && abortController.signal.aborted) {
        abortControllerRef = null;
        return;
      }
      const onError = useAssistantOptions().onError?.();
      if (onError && error2 instanceof Error) {
        onError(error2);
      }
      setError(error2);
    } finally {
      abortControllerRef = null;
      setStatus("awaiting_message");
    }
  };

  /** Send the current input as a user message; ignored when input is empty. */
  const submitMessage = async (event, requestOptions) => {
    event?.preventDefault?.();
    if (input() === "") {
      return;
    }
    append({ role: "user", content: input(), parts: [] }, requestOptions);
  };

  /** Switch to another thread and clear the local messages. */
  const setThreadId = (threadId) => {
    setCurrentThreadId(threadId);
    setMessages([]);
  };

  return {
    append,
    messages,
    setMessages,
    threadId: currentThreadId,
    setThreadId,
    input,
    setInput,
    handleInputChange,
    submitMessage,
    status,
    error,
    stop
  };
}
export {
  experimental_useObject,
  useAssistant,
  useChat,
  useCompletion
};
//# sourceMappingURL=index.mjs.map