@ai-sdk/react
[React](https://react.dev/) UI components for the [AI SDK](https://sdk.vercel.ai/docs): `useChat`, `useCompletion`, `useAssistant`, and `experimental_useObject`.
// src/use-assistant.ts
import { isAbortError } from "@ai-sdk/provider-utils";
import {
generateId,
processAssistantStream
} from "@ai-sdk/ui-utils";
import { useCallback, useRef, useState } from "react";
var getOriginalFetch = () => fetch;
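// useAssistant manages client-side state for a conversation backed by an
// assistant API route: the message list, input value, current thread id,
// request status ('awaiting_message' | 'in_progress'), and any error.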
function useAssistant({
api,
threadId: threadIdParam,
credentials,
headers,
body,
onError,
fetch: fetch2
}) {
const [messages, setMessages] = useState([]);
const [input, setInput] = useState("");
const [currentThreadId, setCurrentThreadId] = useState(
void 0
);
const [status, setStatus] = useState("awaiting_message");
const [error, setError] = useState(void 0);
const handleInputChange = (event) => {
setInput(event.target.value);
};
const abortControllerRef = useRef(null);
const stop = useCallback(() => {
if (abortControllerRef.current) {
abortControllerRef.current.abort();
abortControllerRef.current = null;
}
}, []);
const append = async (message, requestOptions) => {
var _a, _b;
setStatus("in_progress");
setMessages((messages2) => {
var _a2;
return [
...messages2,
{
...message,
id: (_a2 = message.id) != null ? _a2 : generateId()
}
];
});
setInput("");
const abortController = new AbortController();
try {
abortControllerRef.current = abortController;
const actualFetch = fetch2 != null ? fetch2 : getOriginalFetch();
const response = await actualFetch(api, {
method: "POST",
credentials,
signal: abortController.signal,
headers: { "Content-Type": "application/json", ...headers },
body: JSON.stringify({
...body,
// always use user-provided threadId when available:
threadId: (_a = threadIdParam != null ? threadIdParam : currentThreadId) != null ? _a : null,
message: message.content,
// optional request data:
data: requestOptions == null ? void 0 : requestOptions.data
})
});
if (!response.ok) {
throw new Error(
(_b = await response.text()) != null ? _b : "Failed to fetch the assistant response."
);
}
if (response.body == null) {
throw new Error("The response body is empty.");
}
await processAssistantStream({
stream: response.body,
onAssistantMessagePart(value) {
setMessages((messages2) => [
...messages2,
{
id: value.id,
role: value.role,
content: value.content[0].text.value,
parts: []
}
]);
},
onTextPart(value) {
setMessages((messages2) => {
const lastMessage = messages2[messages2.length - 1];
return [
...messages2.slice(0, messages2.length - 1),
{
id: lastMessage.id,
role: lastMessage.role,
content: lastMessage.content + value,
parts: lastMessage.parts
}
];
});
},
onAssistantControlDataPart(value) {
setCurrentThreadId(value.threadId);
setMessages((messages2) => {
const lastMessage = messages2[messages2.length - 1];
lastMessage.id = value.messageId;
return [...messages2.slice(0, messages2.length - 1), lastMessage];
});
},
onDataMessagePart(value) {
setMessages((messages2) => {
var _a2;
return [
...messages2,
{
id: (_a2 = value.id) != null ? _a2 : generateId(),
role: "data",
content: "",
data: value.data,
parts: []
}
];
});
},
onErrorPart(value) {
setError(new Error(value));
}
});
} catch (error2) {
if (isAbortError(error2) && abortController.signal.aborted) {
abortControllerRef.current = null;
return;
}
if (onError && error2 instanceof Error) {
onError(error2);
}
setError(error2);
} finally {
abortControllerRef.current = null;
setStatus("awaiting_message");
}
};
const submitMessage = async (event, requestOptions) => {
var _a;
(_a = event == null ? void 0 : event.preventDefault) == null ? void 0 : _a.call(event);
if (input === "") {
return;
}
append({ role: "user", content: input, parts: [] }, requestOptions);
};
const setThreadId = (threadId) => {
setCurrentThreadId(threadId);
setMessages([]);
};
return {
append,
messages,
setMessages,
threadId: currentThreadId,
setThreadId,
input,
setInput,
handleInputChange,
submitMessage,
status,
error,
stop
};
}
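// Usage sketch (not part of the bundle): a minimal form wired to useAssistant.
// The '/api/assistant' route and the component markup are illustrative assumptions.
//
//   import { useAssistant } from '@ai-sdk/react';
//
//   function AssistantChat() {
//     const { messages, input, handleInputChange, submitMessage, status } =
//       useAssistant({ api: '/api/assistant' });
//     return (
//       <form onSubmit={submitMessage}>
//         {messages.map((m) => (
//           <div key={m.id}>{m.role}: {m.content}</div>
//         ))}
//         <input
//           value={input}
//           onChange={handleInputChange}
//           disabled={status === 'in_progress'}
//         />
//       </form>
//     );
//   }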
// src/use-chat.ts
import {
callChatApi,
extractMaxToolInvocationStep,
fillMessageParts,
generateId as generateIdFunc,
getMessageParts,
isAssistantMessageWithCompletedToolCalls,
prepareAttachmentsForRequest,
shouldResubmitMessages,
updateToolCallResult
} from "@ai-sdk/ui-utils";
import { useCallback as useCallback2, useEffect as useEffect2, useMemo, useRef as useRef2, useState as useState3 } from "react";
import useSWR from "swr";
// src/throttle.ts
import throttleFunction from "throttleit";
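// Throttles fn to at most one call per waitMs; returns fn unchanged when no wait is given.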
function throttle(fn, waitMs) {
return waitMs != null ? throttleFunction(fn, waitMs) : fn;
}
// src/util/use-stable-value.ts
import { isDeepEqualData } from "@ai-sdk/ui-utils";
import { useEffect, useState as useState2 } from "react";
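// Returns a referentially stable copy of latestValue: state only updates when a
// deep comparison detects a change, so the result is safe to use in dependency arrays.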
function useStableValue(latestValue) {
const [value, setValue] = useState2(latestValue);
useEffect(() => {
if (!isDeepEqualData(latestValue, value)) {
setValue(latestValue);
}
}, [latestValue, value]);
return value;
}
// src/use-chat.ts
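// useChat manages streaming chat state in SWR caches keyed by the chat id:
// messages, stream data, status ('submitted' | 'streaming' | 'ready' | 'error'),
// and error, plus helpers for appending, reloading, tool results, and aborting.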
function useChat({
api = "/api/chat",
id,
initialMessages,
initialInput = "",
sendExtraMessageFields,
onToolCall,
experimental_prepareRequestBody,
maxSteps = 1,
streamProtocol = "data",
onResponse,
onFinish,
onError,
credentials,
headers,
body,
generateId: generateId2 = generateIdFunc,
fetch: fetch2,
keepLastMessageOnError = true,
experimental_throttle: throttleWaitMs
} = {}) {
const [hookId] = useState3(generateId2);
const chatId = id != null ? id : hookId;
const chatKey = typeof api === "string" ? [api, chatId] : chatId;
const stableInitialMessages = useStableValue(initialMessages != null ? initialMessages : []);
const processedInitialMessages = useMemo(
() => fillMessageParts(stableInitialMessages),
[stableInitialMessages]
);
const { data: messages, mutate } = useSWR(
[chatKey, "messages"],
null,
{ fallbackData: processedInitialMessages }
);
const messagesRef = useRef2(messages || []);
useEffect2(() => {
messagesRef.current = messages || [];
}, [messages]);
const { data: streamData, mutate: mutateStreamData } = useSWR([chatKey, "streamData"], null);
const streamDataRef = useRef2(streamData);
useEffect2(() => {
streamDataRef.current = streamData;
}, [streamData]);
const { data: status = "ready", mutate: mutateStatus } = useSWR([chatKey, "status"], null);
const { data: error = void 0, mutate: setError } = useSWR([chatKey, "error"], null);
const abortControllerRef = useRef2(null);
const extraMetadataRef = useRef2({
credentials,
headers,
body
});
useEffect2(() => {
extraMetadataRef.current = {
credentials,
headers,
body
};
}, [credentials, headers, body]);
const triggerRequest = useCallback2(
async (chatRequest) => {
var _a, _b;
mutateStatus("submitted");
setError(void 0);
const chatMessages = fillMessageParts(chatRequest.messages);
const messageCount = chatMessages.length;
const maxStep = extractMaxToolInvocationStep(
(_a = chatMessages[chatMessages.length - 1]) == null ? void 0 : _a.toolInvocations
);
try {
const abortController = new AbortController();
abortControllerRef.current = abortController;
const throttledMutate = throttle(mutate, throttleWaitMs);
const throttledMutateStreamData = throttle(
mutateStreamData,
throttleWaitMs
);
const previousMessages = messagesRef.current;
throttledMutate(chatMessages, false);
const constructedMessagesPayload = sendExtraMessageFields ? chatMessages : chatMessages.map(
({
role,
content,
experimental_attachments,
data,
annotations,
toolInvocations,
parts
}) => ({
role,
content,
...experimental_attachments !== void 0 && {
experimental_attachments
},
...data !== void 0 && { data },
...annotations !== void 0 && { annotations },
...toolInvocations !== void 0 && { toolInvocations },
...parts !== void 0 && { parts }
})
);
const existingData = streamDataRef.current;
await callChatApi({
api,
body: (_b = experimental_prepareRequestBody == null ? void 0 : experimental_prepareRequestBody({
id: chatId,
messages: chatMessages,
requestData: chatRequest.data,
requestBody: chatRequest.body
})) != null ? _b : {
id: chatId,
messages: constructedMessagesPayload,
data: chatRequest.data,
...extraMetadataRef.current.body,
...chatRequest.body
},
streamProtocol,
credentials: extraMetadataRef.current.credentials,
headers: {
...extraMetadataRef.current.headers,
...chatRequest.headers
},
abortController: () => abortControllerRef.current,
restoreMessagesOnFailure() {
if (!keepLastMessageOnError) {
throttledMutate(previousMessages, false);
}
},
onResponse,
onUpdate({ message, data, replaceLastMessage }) {
mutateStatus("streaming");
throttledMutate(
[
...replaceLastMessage ? chatMessages.slice(0, chatMessages.length - 1) : chatMessages,
message
],
false
);
if (data == null ? void 0 : data.length) {
throttledMutateStreamData(
[...existingData != null ? existingData : [], ...data],
false
);
}
},
onToolCall,
onFinish,
generateId: generateId2,
fetch: fetch2,
lastMessage: chatMessages[chatMessages.length - 1]
});
abortControllerRef.current = null;
mutateStatus("ready");
} catch (err) {
if (err.name === "AbortError") {
abortControllerRef.current = null;
mutateStatus("ready");
return null;
}
if (onError && err instanceof Error) {
onError(err);
}
setError(err);
mutateStatus("error");
}
const messages2 = messagesRef.current;
if (shouldResubmitMessages({
originalMaxToolInvocationStep: maxStep,
originalMessageCount: messageCount,
maxSteps,
messages: messages2
})) {
await triggerRequest({ messages: messages2 });
}
},
[
mutate,
mutateStatus,
api,
extraMetadataRef,
onResponse,
onFinish,
onError,
setError,
mutateStreamData,
streamDataRef,
streamProtocol,
sendExtraMessageFields,
experimental_prepareRequestBody,
onToolCall,
maxSteps,
messagesRef,
abortControllerRef,
generateId2,
fetch2,
keepLastMessageOnError,
throttleWaitMs,
chatId
]
);
const append = useCallback2(
async (message, {
data,
headers: headers2,
body: body2,
experimental_attachments
} = {}) => {
var _a, _b;
const attachmentsForRequest = await prepareAttachmentsForRequest(
experimental_attachments
);
const messages2 = messagesRef.current.concat({
...message,
id: (_a = message.id) != null ? _a : generateId2(),
createdAt: (_b = message.createdAt) != null ? _b : /* @__PURE__ */ new Date(),
experimental_attachments: attachmentsForRequest.length > 0 ? attachmentsForRequest : void 0,
parts: getMessageParts(message)
});
return triggerRequest({ messages: messages2, headers: headers2, body: body2, data });
},
[triggerRequest, generateId2]
);
const reload = useCallback2(
async ({ data, headers: headers2, body: body2 } = {}) => {
const messages2 = messagesRef.current;
if (messages2.length === 0) {
return null;
}
const lastMessage = messages2[messages2.length - 1];
return triggerRequest({
messages: lastMessage.role === "assistant" ? messages2.slice(0, -1) : messages2,
headers: headers2,
body: body2,
data
});
},
[triggerRequest]
);
const stop = useCallback2(() => {
if (abortControllerRef.current) {
abortControllerRef.current.abort();
abortControllerRef.current = null;
}
}, []);
const setMessages = useCallback2(
(messages2) => {
if (typeof messages2 === "function") {
messages2 = messages2(messagesRef.current);
}
const messagesWithParts = fillMessageParts(messages2);
mutate(messagesWithParts, false);
messagesRef.current = messagesWithParts;
},
[mutate]
);
const setData = useCallback2(
(data) => {
if (typeof data === "function") {
data = data(streamDataRef.current);
}
mutateStreamData(data, false);
streamDataRef.current = data;
},
[mutateStreamData]
);
const [input, setInput] = useState3(initialInput);
const handleSubmit = useCallback2(
async (event, options = {}, metadata) => {
var _a;
(_a = event == null ? void 0 : event.preventDefault) == null ? void 0 : _a.call(event);
if (!input && !options.allowEmptySubmit)
return;
if (metadata) {
extraMetadataRef.current = {
...extraMetadataRef.current,
...metadata
};
}
const attachmentsForRequest = await prepareAttachmentsForRequest(
options.experimental_attachments
);
const messages2 = messagesRef.current.concat({
id: generateId2(),
createdAt: /* @__PURE__ */ new Date(),
role: "user",
content: input,
experimental_attachments: attachmentsForRequest.length > 0 ? attachmentsForRequest : void 0,
parts: [{ type: "text", text: input }]
});
const chatRequest = {
messages: messages2,
headers: options.headers,
body: options.body,
data: options.data
};
triggerRequest(chatRequest);
setInput("");
},
[input, generateId2, triggerRequest]
);
const handleInputChange = (e) => {
setInput(e.target.value);
};
const addToolResult = useCallback2(
({ toolCallId, result }) => {
const currentMessages = messagesRef.current;
updateToolCallResult({
messages: currentMessages,
toolCallId,
toolResult: result
});
mutate(
[
...currentMessages.slice(0, currentMessages.length - 1),
{ ...currentMessages[currentMessages.length - 1] }
],
false
);
if (status === "submitted" || status === "streaming") {
return;
}
const lastMessage = currentMessages[currentMessages.length - 1];
if (isAssistantMessageWithCompletedToolCalls(lastMessage)) {
triggerRequest({ messages: currentMessages });
}
},
[mutate, status, triggerRequest]
);
return {
messages: messages != null ? messages : [],
id: chatId,
setMessages,
data: streamData,
setData,
error,
append,
reload,
stop,
input,
setInput,
handleInputChange,
handleSubmit,
isLoading: status === "submitted" || status === "streaming",
status,
addToolResult
};
}
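// Usage sketch (not part of the bundle): wiring useChat to a basic chat form.
// The component name and markup are illustrative assumptions; api defaults to '/api/chat'.
//
//   import { useChat } from '@ai-sdk/react';
//
//   function Chat() {
//     const { messages, input, handleInputChange, handleSubmit, status, stop } = useChat();
//     return (
//       <div>
//         {messages.map((m) => (
//           <div key={m.id}>{m.role}: {m.content}</div>
//         ))}
//         <form onSubmit={handleSubmit}>
//           <input value={input} onChange={handleInputChange} />
//           {status === 'streaming' && (
//             <button type="button" onClick={stop}>Stop</button>
//           )}
//         </form>
//       </div>
//     );
//   }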
// src/use-completion.ts
import {
callCompletionApi
} from "@ai-sdk/ui-utils";
import { useCallback as useCallback3, useEffect as useEffect3, useId, useRef as useRef3, useState as useState4 } from "react";
import useSWR2 from "swr";
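// useCompletion manages streaming text-completion state (completion text, input,
// loading flag, stream data, error) around callCompletionApi.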
function useCompletion({
api = "/api/completion",
id,
initialCompletion = "",
initialInput = "",
credentials,
headers,
body,
streamProtocol = "data",
fetch: fetch2,
onResponse,
onFinish,
onError,
experimental_throttle: throttleWaitMs
} = {}) {
const hookId = useId();
const completionId = id || hookId;
const { data, mutate } = useSWR2([api, completionId], null, {
fallbackData: initialCompletion
});
const { data: isLoading = false, mutate: mutateLoading } = useSWR2(
[completionId, "loading"],
null
);
const { data: streamData, mutate: mutateStreamData } = useSWR2([completionId, "streamData"], null);
const [error, setError] = useState4(void 0);
const completion = data;
const [abortController, setAbortController] = useState4(null);
const extraMetadataRef = useRef3({
credentials,
headers,
body
});
useEffect3(() => {
extraMetadataRef.current = {
credentials,
headers,
body
};
}, [credentials, headers, body]);
const triggerRequest = useCallback3(
async (prompt, options) => callCompletionApi({
api,
prompt,
credentials: extraMetadataRef.current.credentials,
headers: { ...extraMetadataRef.current.headers, ...options == null ? void 0 : options.headers },
body: {
...extraMetadataRef.current.body,
...options == null ? void 0 : options.body
},
streamProtocol,
fetch: fetch2,
// throttle streamed ui updates:
setCompletion: throttle(
(completion2) => mutate(completion2, false),
throttleWaitMs
),
onData: throttle(
(data2) => mutateStreamData([...streamData != null ? streamData : [], ...data2 != null ? data2 : []], false),
throttleWaitMs
),
setLoading: mutateLoading,
setError,
setAbortController,
onResponse,
onFinish,
onError
}),
[
mutate,
mutateLoading,
api,
extraMetadataRef,
setAbortController,
onResponse,
onFinish,
onError,
setError,
streamData,
streamProtocol,
fetch2,
mutateStreamData,
throttleWaitMs
]
);
const stop = useCallback3(() => {
if (abortController) {
abortController.abort();
setAbortController(null);
}
}, [abortController]);
const setCompletion = useCallback3(
(completion2) => {
mutate(completion2, false);
},
[mutate]
);
const complete = useCallback3(
async (prompt, options) => {
return triggerRequest(prompt, options);
},
[triggerRequest]
);
const [input, setInput] = useState4(initialInput);
const handleSubmit = useCallback3(
(event) => {
var _a;
(_a = event == null ? void 0 : event.preventDefault) == null ? void 0 : _a.call(event);
return input ? complete(input) : void 0;
},
[input, complete]
);
const handleInputChange = useCallback3(
(e) => {
setInput(e.target.value);
},
[setInput]
);
return {
completion,
complete,
error,
setCompletion,
stop,
input,
setInput,
handleInputChange,
handleSubmit,
isLoading,
data: streamData
};
}
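// Usage sketch (not part of the bundle): streaming a completion for a prompt.
// The component markup is an illustrative assumption; api defaults to '/api/completion'.
//
//   import { useCompletion } from '@ai-sdk/react';
//
//   function Completion() {
//     const { completion, input, handleInputChange, handleSubmit, isLoading } = useCompletion();
//     return (
//       <form onSubmit={handleSubmit}>
//         <input value={input} onChange={handleInputChange} disabled={isLoading} />
//         <p>{completion}</p>
//       </form>
//     );
//   }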
// src/use-object.ts
import {
isAbortError as isAbortError2,
safeValidateTypes
} from "@ai-sdk/provider-utils";
import {
asSchema,
isDeepEqualData as isDeepEqualData2,
parsePartialJson
} from "@ai-sdk/ui-utils";
import { useCallback as useCallback4, useId as useId2, useRef as useRef4, useState as useState5 } from "react";
import useSWR3 from "swr";
var getOriginalFetch2 = () => fetch;
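// useObject POSTs the submitted input to the api route, streams the response,
// repeatedly parses the partial JSON, and exposes the latest object; the schema
// is only validated in onFinish once the stream closes.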
function useObject({
api,
id,
schema,
// required, in the future we will use it for validation
initialValue,
fetch: fetch2,
onError,
onFinish,
headers,
credentials
}) {
const hookId = useId2();
const completionId = id != null ? id : hookId;
const { data, mutate } = useSWR3(
[api, completionId],
null,
{ fallbackData: initialValue }
);
const [error, setError] = useState5(void 0);
const [isLoading, setIsLoading] = useState5(false);
const abortControllerRef = useRef4(null);
const stop = useCallback4(() => {
var _a;
try {
(_a = abortControllerRef.current) == null ? void 0 : _a.abort();
} catch (ignored) {
} finally {
setIsLoading(false);
abortControllerRef.current = null;
}
}, []);
const submit = async (input) => {
var _a;
try {
mutate(void 0);
setIsLoading(true);
setError(void 0);
const abortController = new AbortController();
abortControllerRef.current = abortController;
const actualFetch = fetch2 != null ? fetch2 : getOriginalFetch2();
const response = await actualFetch(api, {
method: "POST",
headers: {
"Content-Type": "application/json",
...headers
},
credentials,
signal: abortController.signal,
body: JSON.stringify(input)
});
if (!response.ok) {
throw new Error(
(_a = await response.text()) != null ? _a : "Failed to fetch the response."
);
}
if (response.body == null) {
throw new Error("The response body is empty.");
}
let accumulatedText = "";
let latestObject = void 0;
await response.body.pipeThrough(new TextDecoderStream()).pipeTo(
new WritableStream({
write(chunk) {
accumulatedText += chunk;
const { value } = parsePartialJson(accumulatedText);
const currentObject = value;
if (!isDeepEqualData2(latestObject, currentObject)) {
latestObject = currentObject;
mutate(currentObject);
}
},
close() {
setIsLoading(false);
abortControllerRef.current = null;
if (onFinish != null) {
const validationResult = safeValidateTypes({
value: latestObject,
schema: asSchema(schema)
});
onFinish(
validationResult.success ? { object: validationResult.value, error: void 0 } : { object: void 0, error: validationResult.error }
);
}
}
})
);
} catch (error2) {
if (isAbortError2(error2)) {
return;
}
if (onError && error2 instanceof Error) {
onError(error2);
}
setIsLoading(false);
setError(error2 instanceof Error ? error2 : new Error(String(error2)));
}
};
return {
submit,
object: data,
error,
isLoading,
stop
};
}
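// Usage sketch (not part of the bundle): streaming a structured object.
// The '/api/object' route and the zod schema are illustrative assumptions.
//
//   import { experimental_useObject as useObject } from '@ai-sdk/react';
//   import { z } from 'zod';
//
//   const notificationSchema = z.object({
//     notifications: z.array(z.object({ name: z.string(), message: z.string() })),
//   });
//
//   function Notifications() {
//     const { object, submit, isLoading } = useObject({
//       api: '/api/object',
//       schema: notificationSchema,
//     });
//     return (
//       <div>
//         <button onClick={() => submit('Generate three notifications')} disabled={isLoading}>
//           Generate
//         </button>
//         <pre>{JSON.stringify(object, null, 2)}</pre>
//       </div>
//     );
//   }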
var experimental_useObject = useObject;
export {
experimental_useObject,
useAssistant,
useChat,
useCompletion
};
//# sourceMappingURL=index.mjs.map