@ai-sdk/solid
Version:
> **Warning** `@ai-sdk/solid` has been deprecated and will be removed in AI SDK 5
848 lines (842 loc) • 24.7 kB
JavaScript
// src/use-chat.ts
import {
callChatApi,
extractMaxToolInvocationStep,
fillMessageParts,
generateId as generateIdFunc,
getMessageParts,
isAssistantMessageWithCompletedToolCalls,
prepareAttachmentsForRequest,
shouldResubmitMessages,
updateToolCallResult
} from "@ai-sdk/ui-utils";
import {
createEffect,
createMemo as createMemo2,
createSignal
} from "solid-js";
import { createStore, reconcile } from "solid-js/store";
// src/utils/convert-to-accessor-options.ts
import { createMemo } from "solid-js";
function convertToAccessorOptions(options) {
const resolvedOptions = typeof options === "function" ? options() : options;
return Object.entries(resolvedOptions).reduce(
(reactiveOptions, [key, value]) => {
reactiveOptions[key] = createMemo(() => value);
return reactiveOptions;
},
{}
);
}
// src/utils/reactive-lru.ts
import { batch } from "solid-js";
import { TriggerCache } from "@solid-primitives/trigger";
// Trigger key standing for "the set of keys as a whole" (size, iteration).
var $KEYS = Symbol("track-keys");
/**
 * A `Map` with Solid fine-grained reactivity and least-recently-used eviction.
 *
 * Reads (`has`, `get`, `size`, iteration, `forEach`) subscribe the current
 * reactive scope through `TriggerCache.track`. Writes (`set`, `delete`,
 * `clear`) notify subscribers through `TriggerCache.dirty` inside a `batch`.
 *
 * `#accessList` holds keys ordered from least to most recently used. When it
 * grows past `maxSize`, the oldest key is removed via `delete`, which
 * notifies its readers. `has`/`get` of an existing key (and `forEach`) also
 * count as a "use" and refresh the key's position.
 */
var ReactiveLRU = class extends Map {
  // Per-key triggers for membership changes, plus $KEYS for the key set.
  #keyTriggers = new TriggerCache();
  // Per-key triggers for value changes.
  #valueTriggers = new TriggerCache();
  #maxSize;
  // Keys ordered least -> most recently used.
  #accessList = [];
  /**
   * @param maxSize Maximum number of entries kept before evicting (default 10).
   * @param initial Optional iterable of `[key, value]` pairs inserted in order.
   */
  constructor(maxSize = 10, initial) {
    super();
    this.#maxSize = maxSize;
    if (initial) {
      for (const [key, value] of initial) {
        this.set(key, value);
      }
    }
  }
  /** Marks `key` as most recently used and evicts the LRU key if over capacity. */
  #recordAccess(key) {
    const index = this.#accessList.indexOf(key);
    if (index > -1) {
      this.#accessList.splice(index, 1);
    }
    this.#accessList.push(key);
    if (this.#accessList.length > this.#maxSize) {
      const lru = this.#accessList.shift();
      this.delete(lru);
    }
  }
  // reads
  has(key) {
    this.#keyTriggers.track(key);
    const exists = super.has(key);
    if (exists) {
      this.#recordAccess(key);
    }
    return exists;
  }
  get(key) {
    this.#valueTriggers.track(key);
    const value = super.get(key);
    // An entry holding `undefined` is indistinguishable from a miss here, so
    // it does not refresh the LRU position.
    if (value !== void 0) {
      this.#recordAccess(key);
    }
    return value;
  }
  get size() {
    this.#keyTriggers.track($KEYS);
    return super.size;
  }
  // The iterators track $KEYS *before* yielding. Code after the loop in a
  // generator never runs when the consumer stops early (e.g. `break` in
  // for..of), so tracking it last would miss key additions/removals.
  *keys() {
    this.#keyTriggers.track($KEYS);
    for (const key of super.keys()) {
      this.#keyTriggers.track(key);
      yield key;
    }
  }
  *values() {
    this.#keyTriggers.track($KEYS);
    for (const [key, v] of super.entries()) {
      this.#valueTriggers.track(key);
      yield v;
    }
  }
  *entries() {
    this.#keyTriggers.track($KEYS);
    for (const entry of super.entries()) {
      this.#valueTriggers.track(entry[0]);
      yield entry;
    }
  }
  // writes
  set(key, value) {
    batch(() => {
      if (super.has(key)) {
        // Re-setting the identical value only refreshes the LRU position;
        // no subscriber is notified.
        if (super.get(key) === value) {
          this.#recordAccess(key);
          return;
        }
      } else {
        this.#keyTriggers.dirty(key);
        this.#keyTriggers.dirty($KEYS);
      }
      this.#valueTriggers.dirty(key);
      super.set(key, value);
      this.#recordAccess(key);
    });
    return this;
  }
  delete(key) {
    const r = super.delete(key);
    if (r) {
      batch(() => {
        this.#keyTriggers.dirty(key);
        this.#keyTriggers.dirty($KEYS);
        this.#valueTriggers.dirty(key);
        const index = this.#accessList.indexOf(key);
        if (index > -1) {
          this.#accessList.splice(index, 1);
        }
      });
    }
    return r;
  }
  clear() {
    if (super.size) {
      batch(() => {
        for (const v of super.keys()) {
          this.#keyTriggers.dirty(v);
          this.#valueTriggers.dirty(v);
        }
        super.clear();
        this.#accessList = [];
        this.#keyTriggers.dirty($KEYS);
      });
    }
  }
  // callback
  /**
   * Same contract as `Map.prototype.forEach`, including the optional
   * `thisArg` used as `this` for `callbackfn`.
   */
  forEach(callbackfn, thisArg) {
    this.#keyTriggers.track($KEYS);
    for (const [key, v] of super.entries()) {
      this.#valueTriggers.track(key);
      this.#recordAccess(key);
      callbackfn.call(thisArg, v, key, this);
    }
  }
  [Symbol.iterator]() {
    return this.entries();
  }
};
// src/use-chat.ts
// Message cache shared by every useChat instance, keyed by `${api}|${id}|messages`.
var chatCache = new ReactiveLRU();
/**
 * Solid hook for a chat UI backed by a streaming chat endpoint.
 *
 * Messages live in the shared `chatCache`, so hooks using the same `api` and
 * `id` share one conversation. The hook mirrors the cached messages into a
 * Solid store (`messages`) and into the plain `messagesRef` used when
 * building requests.
 *
 * @param rawUseChatOptions Options object, or a function returning one
 *   (`api`, `id`, `initialMessages`, `initialInput`, `generateId`, `headers`,
 *   `body`, `credentials`, `streamProtocol`, `fetch`, `maxSteps`,
 *   `sendExtraMessageFields`, `keepLastMessageOnError`,
 *   `experimental_prepareRequestBody`, `onResponse`, `onFinish`, `onError`,
 *   `onToolCall`).
 * @returns Accessors and actions: `messages`, `id` (a snapshot taken at setup),
 *   `append`, `reload`, `stop`, `setMessages`, `input`/`setInput`,
 *   `handleInputChange`, `handleSubmit`, `isLoading`, `status`, `error`,
 *   `data`/`setData`, `addToolResult`.
 */
function useChat(rawUseChatOptions = {}) {
  const resolvedOptions = createMemo2(
    () => convertToAccessorOptions(rawUseChatOptions)
  );
  const prepareFn = createMemo2(() => {
    const opts = resolvedOptions();
    return opts.experimental_prepareRequestBody?.();
  });
  const useChatOptions = createMemo2(() => ({
    ...resolvedOptions(),
    experimental_prepareRequestBody: prepareFn
  }));
  const api = createMemo2(() => useChatOptions().api?.() ?? "/api/chat");
  const generateId2 = createMemo2(
    () => useChatOptions().generateId?.() ?? generateIdFunc
  );
  // useChatOptions() produces a new object whenever the raw options change,
  // so the chatId memo re-runs on every option change. The fallback id is
  // generated at most once. Otherwise each re-run would switch to a new
  // chat key and the cached conversation would disappear.
  let fallbackChatId;
  const chatId = createMemo2(
    () => useChatOptions().id?.() ?? (fallbackChatId ??= generateId2()())
  );
  const chatKey = createMemo2(() => `${api()}|${chatId()}|messages`);
  const _messages = createMemo2(
    () => chatCache.get(chatKey()) ?? useChatOptions().initialMessages?.() ?? []
  );
  const [messagesStore, setMessagesStore] = createStore(
    fillMessageParts(_messages())
  );
  createEffect(() => {
    setMessagesStore(reconcile(fillMessageParts(_messages()), { merge: true }));
  });
  // Writes go to the shared cache; _messages and the effects above pick them up.
  const mutate = (messages) => {
    chatCache.set(chatKey(), messages);
  };
  const [error, setError] = createSignal(void 0);
  const [streamData, setStreamData] = createSignal(
    void 0
  );
  const [status, setStatus] = createSignal("ready");
  // Plain (non-reactive) mirror of the current messages, read when building requests.
  let messagesRef = fillMessageParts(_messages()) || [];
  createEffect(() => {
    messagesRef = fillMessageParts(_messages()) || [];
  });
  let abortController = null;
  let extraMetadata = {
    credentials: useChatOptions().credentials?.(),
    headers: useChatOptions().headers?.(),
    body: useChatOptions().body?.()
  };
  createEffect(() => {
    extraMetadata = {
      credentials: useChatOptions().credentials?.(),
      headers: useChatOptions().headers?.(),
      body: useChatOptions().body?.()
    };
  });
  /**
   * Sends `chatRequest.messages` to the API, streams the reply into the
   * cache, and re-submits automatically when `shouldResubmitMessages` says
   * another tool step is due.
   */
  const triggerRequest = async (chatRequest) => {
    setError(void 0);
    setStatus("submitted");
    const messageCount = messagesRef.length;
    const maxStep = extractMaxToolInvocationStep(
      chatRequest.messages[chatRequest.messages.length - 1]?.toolInvocations
    );
    try {
      abortController = new AbortController();
      const streamProtocol = useChatOptions().streamProtocol?.() ?? "data";
      const onFinish = useChatOptions().onFinish?.();
      const onResponse = useChatOptions().onResponse?.();
      const onToolCall = useChatOptions().onToolCall?.();
      const sendExtraMessageFields = useChatOptions().sendExtraMessageFields?.();
      const keepLastMessageOnError = useChatOptions().keepLastMessageOnError?.() ?? true;
      const experimental_prepareRequestBody = useChatOptions().experimental_prepareRequestBody?.();
      const previousMessages = messagesRef;
      const chatMessages = fillMessageParts(chatRequest.messages);
      mutate(chatMessages);
      const existingStreamData = streamData() ?? [];
      // Unless extra fields are requested, send only the known message fields.
      const constructedMessagesPayload = sendExtraMessageFields ? chatMessages : chatMessages.map(
        ({
          role,
          content,
          experimental_attachments,
          data,
          annotations,
          toolInvocations,
          parts
        }) => ({
          role,
          content,
          ...experimental_attachments !== void 0 && {
            experimental_attachments
          },
          ...data !== void 0 && { data },
          ...annotations !== void 0 && { annotations },
          ...toolInvocations !== void 0 && { toolInvocations },
          ...parts !== void 0 && { parts }
        })
      );
      await callChatApi({
        api: api(),
        body: experimental_prepareRequestBody?.({
          id: chatId(),
          messages: chatMessages,
          requestData: chatRequest.data,
          requestBody: chatRequest.body
        }) ?? {
          id: chatId(),
          messages: constructedMessagesPayload,
          data: chatRequest.data,
          ...extraMetadata.body,
          ...chatRequest.body
        },
        streamProtocol,
        credentials: extraMetadata.credentials,
        headers: {
          ...extraMetadata.headers,
          ...chatRequest.headers
        },
        abortController: () => abortController,
        restoreMessagesOnFailure() {
          if (!keepLastMessageOnError) {
            mutate(previousMessages);
          }
        },
        onResponse,
        onUpdate({ message, data, replaceLastMessage }) {
          setStatus("streaming");
          mutate([
            ...replaceLastMessage ? chatMessages.slice(0, chatMessages.length - 1) : chatMessages,
            message
          ]);
          if (data?.length) {
            setStreamData([...existingStreamData, ...data]);
          }
        },
        onToolCall,
        onFinish,
        generateId: generateId2(),
        fetch: useChatOptions().fetch?.(),
        lastMessage: chatMessages[chatMessages.length - 1]
      });
      abortController = null;
      setStatus("ready");
    } catch (err) {
      // Aborts (via stop()) are expected; `?.` guards against non-object throws.
      if (err?.name === "AbortError") {
        abortController = null;
        setStatus("ready");
        return null;
      }
      const onError = useChatOptions().onError?.();
      if (onError && err instanceof Error) {
        onError(err);
      }
      setError(err);
      setStatus("error");
    }
    const maxSteps = useChatOptions().maxSteps?.() ?? 1;
    const messages = messagesRef;
    if (shouldResubmitMessages({
      originalMaxToolInvocationStep: maxStep,
      originalMessageCount: messageCount,
      maxSteps,
      messages
    })) {
      await triggerRequest({ messages });
    }
  };
  const append = async (message, {
    data,
    headers,
    body,
    experimental_attachments = message.experimental_attachments
  } = {}) => {
    const attachmentsForRequest = await prepareAttachmentsForRequest(
      experimental_attachments
    );
    const messages = messagesRef.concat({
      ...message,
      id: message.id ?? generateId2()(),
      createdAt: message.createdAt ?? /* @__PURE__ */ new Date(),
      experimental_attachments: attachmentsForRequest.length > 0 ? attachmentsForRequest : void 0,
      parts: getMessageParts(message)
    });
    return triggerRequest({
      messages,
      headers,
      body,
      data
    });
  };
  // Re-sends the conversation, dropping a trailing assistant reply first.
  const reload = async ({
    data,
    headers,
    body
  } = {}) => {
    if (messagesRef.length === 0) {
      return null;
    }
    const lastMessage = messagesRef[messagesRef.length - 1];
    return triggerRequest({
      messages: lastMessage.role === "assistant" ? messagesRef.slice(0, -1) : messagesRef,
      headers,
      body,
      data
    });
  };
  const stop = () => {
    if (abortController) {
      abortController.abort();
      abortController = null;
    }
  };
  const setMessages = (messagesArg) => {
    if (typeof messagesArg === "function") {
      messagesArg = messagesArg(messagesRef);
    }
    const messagesWithParts = fillMessageParts(messagesArg);
    mutate(messagesWithParts);
    messagesRef = messagesWithParts;
  };
  const setData = (dataArg) => {
    if (typeof dataArg === "function") {
      dataArg = dataArg(streamData());
    }
    setStreamData(dataArg);
  };
  const [input, setInput] = createSignal(
    useChatOptions().initialInput?.() || ""
  );
  const handleSubmit = async (event, options = {}, metadata) => {
    event?.preventDefault?.();
    const inputValue = input();
    if (!inputValue && !options.allowEmptySubmit)
      return;
    const attachmentsForRequest = await prepareAttachmentsForRequest(
      options.experimental_attachments
    );
    if (metadata) {
      extraMetadata = {
        ...extraMetadata,
        ...metadata
      };
    }
    triggerRequest({
      messages: messagesRef.concat({
        id: generateId2()(),
        role: "user",
        content: inputValue,
        createdAt: /* @__PURE__ */ new Date(),
        experimental_attachments: attachmentsForRequest.length > 0 ? attachmentsForRequest : void 0,
        parts: [{ type: "text", text: inputValue }]
      }),
      headers: options.headers,
      body: options.body,
      data: options.data
    });
    setInput("");
  };
  const handleInputChange = (e) => {
    setInput(e.target.value);
  };
  /**
   * Records a client-side tool result and, when no request is in flight and
   * the last assistant message has all tool calls completed, re-submits.
   */
  const addToolResult = ({
    toolCallId,
    result
  }) => {
    const currentMessages = messagesRef ?? [];
    // No message can hold the tool call.
    if (currentMessages.length === 0) {
      return;
    }
    updateToolCallResult({
      messages: currentMessages,
      toolCallId,
      toolResult: result
    });
    // updateToolCallResult works on currentMessages in place. After
    // setMessages(), messagesRef is the very array stored in the cache, and
    // ReactiveLRU.set ignores re-setting an identical value. Publish a new
    // array (with a copied last message) so the update is always observed.
    const updatedMessages = [
      ...currentMessages.slice(0, -1),
      { ...currentMessages[currentMessages.length - 1] }
    ];
    mutate(updatedMessages);
    if (status() === "submitted" || status() === "streaming") {
      return;
    }
    const lastMessage = updatedMessages[updatedMessages.length - 1];
    if (isAssistantMessageWithCompletedToolCalls(lastMessage)) {
      triggerRequest({ messages: updatedMessages });
    }
  };
  const isLoading = createMemo2(
    () => status() === "submitted" || status() === "streaming"
  );
  return {
    // TODO next major release: replace with direct message store access (breaking change)
    messages: () => messagesStore,
    id: chatId(),
    append,
    error,
    reload,
    stop,
    setMessages,
    input,
    setInput,
    handleInputChange,
    handleSubmit,
    isLoading,
    status,
    data: streamData,
    setData,
    addToolResult
  };
}
// src/use-completion.ts
import { callCompletionApi } from "@ai-sdk/ui-utils";
import {
createEffect as createEffect2,
createMemo as createMemo3,
createSignal as createSignal2,
createUniqueId
} from "solid-js";
// Completion cache shared by every useCompletion instance, keyed by `${api}|${id}|completion`.
var completionCache = new ReactiveLRU();
/**
 * Solid hook for a single streaming text completion.
 *
 * The completion text lives in the shared `completionCache`, so hooks using
 * the same `api` and `id` share one completion.
 *
 * @param rawUseCompletionOptions Options object, or a function returning one
 *   (`api`, `id`, `initialCompletion`, `initialInput`, `headers`, `body`,
 *   `credentials`, `streamProtocol`, `fetch`, `onResponse`, `onFinish`,
 *   `onError`).
 * @returns Accessors and actions: `completion`, `complete`, `error`, `stop`,
 *   `setCompletion`, `input`/`setInput`, `handleInputChange`, `handleSubmit`,
 *   `isLoading`, `data`.
 */
function useCompletion(rawUseCompletionOptions = {}) {
  const useCompletionOptions = createMemo3(
    () => convertToAccessorOptions(rawUseCompletionOptions)
  );
  const api = createMemo3(
    () => useCompletionOptions().api?.() ?? "/api/completion"
  );
  // useCompletionOptions() is a new object on every option change, so this
  // memo re-runs. Create the fallback id once so the cache key (and thus the
  // visible completion) stays stable across option changes.
  let fallbackId;
  const idKey = createMemo3(
    () => useCompletionOptions().id?.() ?? (fallbackId ??= `completion-${createUniqueId()}`)
  );
  const completionKey = createMemo3(() => `${api()}|${idKey()}|completion`);
  const completion = createMemo3(
    () => completionCache.get(completionKey()) ?? useCompletionOptions().initialCompletion?.() ?? ""
  );
  const mutate = (data) => {
    completionCache.set(completionKey(), data);
  };
  const [error, setError] = createSignal2(void 0);
  const [streamData, setStreamData] = createSignal2(
    void 0
  );
  const [isLoading, setIsLoading] = createSignal2(false);
  const [abortController, setAbortController] = createSignal2(null);
  let extraMetadata = {
    credentials: useCompletionOptions().credentials?.(),
    headers: useCompletionOptions().headers?.(),
    body: useCompletionOptions().body?.()
  };
  createEffect2(() => {
    extraMetadata = {
      credentials: useCompletionOptions().credentials?.(),
      headers: useCompletionOptions().headers?.(),
      body: useCompletionOptions().body?.()
    };
  });
  // Requests a completion for `prompt`; per-call headers/body override the hook-level ones.
  const complete = async (prompt, options) => {
    const existingData = streamData() ?? [];
    return callCompletionApi({
      api: api(),
      prompt,
      credentials: useCompletionOptions().credentials?.(),
      headers: { ...extraMetadata.headers, ...options?.headers },
      body: {
        ...extraMetadata.body,
        ...options?.body
      },
      streamProtocol: useCompletionOptions().streamProtocol?.(),
      setCompletion: mutate,
      setLoading: setIsLoading,
      setError,
      setAbortController,
      onResponse: useCompletionOptions().onResponse?.(),
      onFinish: useCompletionOptions().onFinish?.(),
      onError: useCompletionOptions().onError?.(),
      onData: (data) => {
        setStreamData([...existingData, ...data ?? []]);
      },
      fetch: useCompletionOptions().fetch?.()
    });
  };
  const stop = () => {
    if (abortController()) {
      abortController().abort();
    }
  };
  const setCompletion = (completion2) => {
    mutate(completion2);
  };
  const [input, setInput] = createSignal2(
    useCompletionOptions().initialInput?.() ?? ""
  );
  const handleInputChange = (event) => {
    setInput(event.target.value);
  };
  // Submits the current input; an empty input is a no-op returning undefined.
  const handleSubmit = (event) => {
    event?.preventDefault?.();
    const inputValue = input();
    return inputValue ? complete(inputValue) : void 0;
  };
  return {
    completion,
    complete,
    error,
    stop,
    setCompletion,
    input,
    setInput,
    handleInputChange,
    handleSubmit,
    isLoading,
    data: streamData
  };
}
// src/use-object.ts
import {
isAbortError,
safeValidateTypes
} from "@ai-sdk/provider-utils";
import {
asSchema,
isDeepEqualData,
parsePartialJson
} from "@ai-sdk/ui-utils";
import { createMemo as createMemo4, createSignal as createSignal3, createUniqueId as createUniqueId2 } from "solid-js";
// Resolves the global fetch lazily (used when no `fetch` option is given).
var getOriginalFetch = () => fetch;
// Object cache shared by every useObject instance, keyed by id.
var objectCache = new ReactiveLRU();
/**
 * Solid hook that POSTs JSON input and streams back a partial JSON object.
 *
 * Each text chunk is appended to the accumulated response and parsed with
 * `parsePartialJson`. The cached object is updated whenever the parsed value
 * changes. When the stream closes, `onFinish` receives the final object
 * validated against `schema`.
 *
 * @param rawUseObjectOptions Options object, or a function returning one
 *   (`api`, `schema`, `id`, `initialValue`, `fetch`, `headers`, `credentials`,
 *   `onFinish`, `onError`).
 * @returns `{ submit, object, error, isLoading, stop }`.
 */
function useObject(rawUseObjectOptions) {
  const useObjectOptions = createMemo4(
    () => convertToAccessorOptions(rawUseObjectOptions)
  );
  // Create the fallback id once; the memo re-runs on every option change and
  // a fresh id would point at an empty cache slot.
  let fallbackId;
  const idKey = createMemo4(
    () => useObjectOptions().id?.() ?? (fallbackId ??= `object-${createUniqueId2()}`)
  );
  const data = createMemo4(
    () => objectCache.get(idKey()) ?? useObjectOptions().initialValue?.()
  );
  const mutate = (value) => {
    objectCache.set(idKey(), value);
  };
  const [error, setError] = createSignal3();
  const [isLoading, setIsLoading] = createSignal3(false);
  let abortControllerRef = null;
  const stop = () => {
    try {
      abortControllerRef?.abort();
    } catch (ignored) {
    } finally {
      setIsLoading(false);
      abortControllerRef = null;
    }
  };
  const submit = async (input) => {
    try {
      mutate(void 0);
      setIsLoading(true);
      setError(void 0);
      const abortController = new AbortController();
      abortControllerRef = abortController;
      // Honor the user's `fetch` option; a bare `fetch` here would be the global.
      const actualFetch = useObjectOptions().fetch?.() ?? getOriginalFetch();
      const response = await actualFetch(useObjectOptions().api(), {
        method: "POST",
        headers: {
          "Content-Type": "application/json",
          ...useObjectOptions().headers?.()
        },
        credentials: useObjectOptions().credentials?.(),
        signal: abortController.signal,
        body: JSON.stringify(input)
      });
      if (!response.ok) {
        // text() always yields a string, so use || to fall back on an empty body.
        throw new Error(
          (await response.text()) || "Failed to fetch the response."
        );
      }
      if (response.body == null) {
        throw new Error("The response body is empty.");
      }
      let accumulatedText = "";
      let latestObject = void 0;
      await response.body.pipeThrough(new TextDecoderStream()).pipeTo(
        new WritableStream({
          write(chunk) {
            accumulatedText += chunk;
            const { value } = parsePartialJson(accumulatedText);
            const currentObject = value;
            if (!isDeepEqualData(latestObject, currentObject)) {
              latestObject = currentObject;
              mutate(currentObject);
            }
          },
          close() {
            setIsLoading(false);
            abortControllerRef = null;
            const onFinish = useObjectOptions().onFinish?.();
            if (onFinish != null) {
              const validationResult = safeValidateTypes({
                value: latestObject,
                schema: asSchema(useObjectOptions().schema())
              });
              onFinish(
                validationResult.success ? { object: validationResult.value, error: void 0 } : { object: void 0, error: validationResult.error }
              );
            }
          }
        })
      );
    } catch (error2) {
      // stop() already reset the loading state for aborted requests.
      if (isAbortError(error2)) {
        return;
      }
      const onError = useObjectOptions().onError?.();
      if (onError && error2 instanceof Error) {
        onError(error2);
      }
      setIsLoading(false);
      setError(error2 instanceof Error ? error2 : new Error(String(error2)));
    }
  };
  return {
    submit,
    object: data,
    error,
    isLoading,
    stop
  };
}
var experimental_useObject = useObject;
// src/use-assistant.ts
import { isAbortError as isAbortError2 } from "@ai-sdk/provider-utils";
import {
generateId,
processAssistantStream
} from "@ai-sdk/ui-utils";
import { createMemo as createMemo5, createSignal as createSignal4 } from "solid-js";
import { createStore as createStore2 } from "solid-js/store";
// Resolves the global fetch lazily (used when no `fetch` option is given).
var getOriginalFetch2 = () => fetch;
/**
 * Solid hook for an assistant-thread UI fed by an assistant stream.
 *
 * `append` adds the user message locally, POSTs it (with the thread id and
 * optional request data) to `api`, and applies stream parts to the message
 * store as they arrive.
 *
 * @param rawUseAssistantOptions Options object, or a function returning one
 *   (`api`, `threadId`, `fetch`, `headers`, `body`, `credentials`, `onError`).
 * @returns Accessors and actions: `append`, `messages`, `setMessages`,
 *   `threadId`, `setThreadId`, `input`/`setInput`, `handleInputChange`,
 *   `submitMessage`, `status`, `error`, `stop`.
 */
function useAssistant(rawUseAssistantOptions) {
  const useAssistantOptions = createMemo5(
    () => convertToAccessorOptions(rawUseAssistantOptions)
  );
  const [messages, setMessages] = createStore2([]);
  const [input, setInput] = createSignal4("");
  const [currentThreadId, setCurrentThreadId] = createSignal4();
  const [status, setStatus] = createSignal4("awaiting_message");
  const [error, setError] = createSignal4();
  const handleInputChange = (event) => {
    setInput(event.target.value);
  };
  let abortControllerRef = null;
  const stop = () => {
    if (abortControllerRef) {
      abortControllerRef.abort();
      abortControllerRef = null;
    }
  };
  const append = async (message, requestOptions) => {
    setStatus("in_progress");
    setMessages((messages2) => [
      ...messages2,
      {
        ...message,
        id: message.id ?? generateId()
      }
    ]);
    setInput("");
    const abortController = new AbortController();
    try {
      abortControllerRef = abortController;
      // Honor the user's `fetch` option; a bare `fetch` here would be the global.
      const actualFetch = useAssistantOptions().fetch?.() ?? getOriginalFetch2();
      const response = await actualFetch(useAssistantOptions().api(), {
        method: "POST",
        credentials: useAssistantOptions().credentials?.(),
        signal: abortController.signal,
        headers: {
          "Content-Type": "application/json",
          ...useAssistantOptions().headers?.()
        },
        body: JSON.stringify({
          ...useAssistantOptions().body?.(),
          // always use user-provided threadId when available:
          threadId: useAssistantOptions().threadId?.() ?? currentThreadId(),
          message: message.content,
          // optional request data:
          data: requestOptions?.data
        })
      });
      if (!response.ok) {
        // text() always yields a string, so use || to fall back on an empty body.
        throw new Error(
          (await response.text()) || "Failed to fetch the assistant response."
        );
      }
      if (response.body == null) {
        throw new Error("The response body is empty.");
      }
      await processAssistantStream({
        stream: response.body,
        onAssistantMessagePart(value) {
          setMessages((messages2) => [
            ...messages2,
            {
              id: value.id,
              role: value.role,
              content: value.content[0].text.value,
              parts: []
            }
          ]);
        },
        onTextPart(value) {
          // Append streamed text to the last message via a replacement object.
          setMessages((messages2) => {
            const lastMessage = messages2[messages2.length - 1];
            return [
              ...messages2.slice(0, messages2.length - 1),
              {
                id: lastMessage.id,
                role: lastMessage.role,
                content: lastMessage.content + value,
                parts: lastMessage.parts
              }
            ];
          });
        },
        onAssistantControlDataPart(value) {
          setCurrentThreadId(value.threadId);
          // Set the id of the last message. Replace the element rather than
          // assigning to it in place, so the store setter sees a new object
          // and readers of that message are notified.
          setMessages((messages2) => {
            if (messages2.length === 0) {
              return messages2;
            }
            const lastMessage = messages2[messages2.length - 1];
            return [
              ...messages2.slice(0, messages2.length - 1),
              { ...lastMessage, id: value.messageId }
            ];
          });
        },
        onDataMessagePart(value) {
          setMessages((messages2) => [
            ...messages2,
            {
              id: value.id ?? generateId(),
              role: "data",
              content: "",
              data: value.data,
              parts: []
            }
          ]);
        },
        onErrorPart(value) {
          setError(new Error(value));
        }
      });
    } catch (error2) {
      // Aborts triggered through this request's controller (stop()) are expected.
      if (isAbortError2(error2) && abortController.signal.aborted) {
        abortControllerRef = null;
        return;
      }
      const onError = useAssistantOptions().onError?.();
      if (onError && error2 instanceof Error) {
        onError(error2);
      }
      setError(error2);
    } finally {
      abortControllerRef = null;
      setStatus("awaiting_message");
    }
  };
  // Sends the current input as a user message; empty input is ignored.
  const submitMessage = async (event, requestOptions) => {
    event?.preventDefault?.();
    if (input() === "") {
      return;
    }
    append({ role: "user", content: input(), parts: [] }, requestOptions);
  };
  // Switches to another thread and clears the local message list.
  const setThreadId = (threadId) => {
    setCurrentThreadId(threadId);
    setMessages([]);
  };
  return {
    append,
    messages,
    setMessages,
    threadId: currentThreadId,
    setThreadId,
    input,
    setInput,
    handleInputChange,
    submitMessage,
    status,
    error,
    stop
  };
}
export {
experimental_useObject,
useAssistant,
useChat,
useCompletion
};
//# sourceMappingURL=index.mjs.map