@ai-sdk/react
Version:
[React](https://react.dev/) UI components for the [AI SDK](https://ai-sdk.dev/docs).
466 lines (459 loc) • 14 kB
JavaScript
/**
 * Guard used by the emulated private-member helpers.
 *
 * @param obj    Instance whose private member is being accessed.
 * @param member WeakMap/WeakSet acting as the backing store of the member.
 * @param msg    Description of the attempted operation, appended to "Cannot ".
 * @throws {TypeError} When `obj` is not registered in `member`.
 */
var __accessCheck = (obj, member, msg) => {
  const isRegistered = member.has(obj);
  if (isRegistered) {
    return;
  }
  throw TypeError("Cannot " + msg);
};
/**
 * Reads an emulated private member of `obj`.
 *
 * @param obj    Instance owning the member.
 * @param member Backing store of the member.
 * @param getter Optional accessor; when given it is invoked with `obj` as `this`
 *               instead of reading the value from `member`.
 * @returns The accessor's result, or the value stored for `obj` in `member`.
 * @throws {TypeError} Via `__accessCheck` when `obj` is not registered in `member`.
 */
var __privateGet = (obj, member, getter) => {
  __accessCheck(obj, member, "read from private field");
  if (getter) {
    return getter.call(obj);
  }
  return member.get(obj);
};
/**
 * Registers an emulated private member on `obj`.
 *
 * @param obj    Instance receiving the member.
 * @param member Backing store: a WeakSet records membership only, anything else
 *               is treated as a map and stores `value` for `obj`.
 * @param value  Initial value (ignored for WeakSet stores).
 * @throws {TypeError} When `obj` is already registered in `member`.
 */
var __privateAdd = (obj, member, value) => {
  if (member.has(obj)) {
    throw TypeError("Cannot add the same private member more than once");
  }
  if (member instanceof WeakSet) {
    member.add(obj);
  } else {
    member.set(obj, value);
  }
};
/**
 * Writes an emulated private member of `obj`.
 *
 * @param obj    Instance owning the member.
 * @param member Backing store of the member.
 * @param value  Value to write.
 * @param setter Optional accessor; when given it is invoked with `obj` as `this`
 *               and `value` as argument instead of writing to `member`.
 * @returns `value`, so the call can be used as an assignment expression.
 * @throws {TypeError} Via `__accessCheck` when `obj` is not registered in `member`.
 */
var __privateSet = (obj, member, value, setter) => {
  __accessCheck(obj, member, "write to private field");
  if (setter) {
    setter.call(obj, value);
  } else {
    member.set(obj, value);
  }
  return value;
};
// src/use-chat.ts
import { useCallback, useEffect, useRef, useSyncExternalStore } from "react";
// src/chat.react.ts
import { AbstractChat } from "ai";
// src/throttle.ts
import throttleFunction from "throttleit";
/**
 * Optionally wraps `fn` with the `throttleit` throttler.
 *
 * @param fn     Function to wrap.
 * @param waitMs Wait passed to `throttleit`; when `null`/`undefined`, `fn` is
 *               returned untouched (same reference).
 * @returns `fn` itself, or the result of `throttleFunction(fn, waitMs)`.
 */
function throttle(fn, waitMs) {
  if (waitMs == null) {
    return fn;
  }
  return throttleFunction(fn, waitMs);
}
// src/chat.react.ts
// Backing stores for ReactChatState's emulated private members, keyed by instance.
var _messages = new WeakMap();
var _status = new WeakMap();
var _error = new WeakMap();
var _messagesCallbacks = new WeakMap();
var _statusCallbacks = new WeakMap();
var _errorCallbacks = new WeakMap();
var _callMessagesCallbacks = new WeakMap();
var _callStatusCallbacks = new WeakMap();
var _callErrorCallbacks = new WeakMap();
/**
 * Observable container for a chat's `messages`, `status` and `error`.
 *
 * Every write to one of the three values synchronously invokes the callbacks
 * registered for that value. The message list is always replaced by a new array
 * rather than mutated in place.
 */
var ReactChatState = class {
  /**
   * @param initialMessages Starting message list; stored by reference (not copied).
   */
  constructor(initialMessages = []) {
    __privateAdd(this, _messages, void 0);
    __privateAdd(this, _status, "ready");
    __privateAdd(this, _error, void 0);
    __privateAdd(this, _messagesCallbacks, /* @__PURE__ */ new Set());
    __privateAdd(this, _statusCallbacks, /* @__PURE__ */ new Set());
    __privateAdd(this, _errorCallbacks, /* @__PURE__ */ new Set());

    // Stores a new message array and notifies the message listeners.
    const commitMessages = (nextMessages) => {
      __privateSet(this, _messages, nextMessages);
      __privateGet(this, _callMessagesCallbacks).call(this);
    };
    // Adds `listener` to the listener set held in `listenersField`; returns its unsubscribe function.
    const subscribe = (listenersField, listener) => {
      __privateGet(this, listenersField).add(listener);
      return () => {
        __privateGet(this, listenersField).delete(listener);
      };
    };
    // Builds a function that calls every listener currently held in `listenersField`.
    const makeNotifier = (listenersField) => () => {
      for (const listener of __privateGet(this, listenersField)) {
        listener();
      }
    };

    // Appends `message` (via Array.prototype.concat) and notifies.
    this.pushMessage = (message) => {
      const appended = __privateGet(this, _messages).concat(message);
      commitMessages(appended);
    };
    // Drops the last message and notifies.
    this.popMessage = () => {
      const withoutLast = __privateGet(this, _messages).slice(0, -1);
      commitMessages(withoutLast);
    };
    // Replaces the message at `index` with a deep clone of `message` and notifies.
    this.replaceMessage = (index, message) => {
      const before = __privateGet(this, _messages).slice(0, index);
      // The message is deep cloned so that the React Compiler (in RC at the time of
      // writing) detects changes in deeply nested parts/metadata.
      const replacement = this.snapshot(message);
      const after = __privateGet(this, _messages).slice(index + 1);
      commitMessages([...before, replacement, ...after]);
    };
    // Deep copy through the structured clone algorithm.
    this.snapshot = (value) => {
      return structuredClone(value);
    };
    // Registers a message listener; a truthy `throttleWaitMs` wraps it with `throttle` first.
    this["~registerMessagesCallback"] = (onChange, throttleWaitMs) => {
      let callback = onChange;
      if (throttleWaitMs) {
        callback = throttle(onChange, throttleWaitMs);
      }
      return subscribe(_messagesCallbacks, callback);
    };
    this["~registerStatusCallback"] = (onChange) => {
      return subscribe(_statusCallbacks, onChange);
    };
    this["~registerErrorCallback"] = (onChange) => {
      return subscribe(_errorCallbacks, onChange);
    };

    __privateAdd(this, _callMessagesCallbacks, makeNotifier(_messagesCallbacks));
    __privateAdd(this, _callStatusCallbacks, makeNotifier(_statusCallbacks));
    __privateAdd(this, _callErrorCallbacks, makeNotifier(_errorCallbacks));
    __privateSet(this, _messages, initialMessages);
  }
  /** Current status; initialised to "ready". */
  get status() {
    return __privateGet(this, _status);
  }
  /** Stores the status, then notifies the status listeners. */
  set status(nextStatus) {
    __privateSet(this, _status, nextStatus);
    __privateGet(this, _callStatusCallbacks).call(this);
  }
  /** Current error, or `undefined`. */
  get error() {
    return __privateGet(this, _error);
  }
  /** Stores the error, then notifies the error listeners. */
  set error(nextError) {
    __privateSet(this, _error, nextError);
    __privateGet(this, _callErrorCallbacks).call(this);
  }
  /** Current message array. */
  get messages() {
    return __privateGet(this, _messages);
  }
  /** Stores a shallow copy of the given messages, then notifies the message listeners. */
  set messages(nextMessages) {
    __privateSet(this, _messages, [...nextMessages]);
    __privateGet(this, _callMessagesCallbacks).call(this);
  }
};
// Backing store for Chat's emulated private state member, keyed by instance.
var _state = new WeakMap();
/**
 * `AbstractChat` subclass backed by a `ReactChatState`.
 *
 * The constructor builds a `ReactChatState` from `messages`, hands it to the base
 * class as `state`, and exposes the state's three `~register…Callback`
 * subscription functions on the chat instance itself.
 */
var Chat = class extends AbstractChat {
  /**
   * @param messages Initial messages for the `ReactChatState`.
   * @param init     Remaining options, forwarded to `AbstractChat` together with `state`.
   */
  constructor({ messages, ...init }) {
    const reactState = new ReactChatState(messages);
    super({ ...init, state: reactState });
    __privateAdd(this, _state, reactState);
    this["~registerMessagesCallback"] = (onChange, throttleWaitMs) => {
      const state = __privateGet(this, _state);
      return state["~registerMessagesCallback"](onChange, throttleWaitMs);
    };
    this["~registerStatusCallback"] = (onChange) => {
      const state = __privateGet(this, _state);
      return state["~registerStatusCallback"](onChange);
    };
    this["~registerErrorCallback"] = (onChange) => {
      const state = __privateGet(this, _state);
      return state["~registerErrorCallback"](onChange);
    };
  }
};
// src/use-chat.ts
/**
 * React hook that binds a chat instance to component state.
 *
 * The hook either adopts the chat passed as `options.chat` or creates a `Chat`
 * from the remaining options, keeps it in a ref across renders, and subscribes to
 * its messages, status and error through `useSyncExternalStore`.
 *
 * @param experimental_throttle Optional wait (ms) used to throttle message updates.
 * @param resume                When true, `resumeStream()` is called from an effect.
 * @param options               Either `{ chat }` or the options forwarded to `new Chat(...)`.
 * @returns The chat id, the reactive `messages` / `status` / `error` values,
 *          `setMessages`, and the chat's action methods.
 */
function useChat({
  experimental_throttle: throttleWaitMs,
  resume = false,
  ...options
} = {}) {
  // An option that is present but null/undefined is treated like an omitted one.
  // Otherwise `chat: undefined` would put `undefined` into the ref, and
  // `id: undefined` would be compared against the chat's own id on every render,
  // reporting a mismatch and recreating the chat (dropping its state) each time.
  const providedChat = options.chat != null ? options.chat : null;
  const providedId = options.id != null ? options.id : null;
  const createChat = () => providedChat != null ? providedChat : new Chat(options);

  // Lazy initialisation: a chat is only constructed when one is actually needed,
  // not on every render just to be discarded as an unused `useRef` initial value.
  const chatRef = useRef(null);
  if (chatRef.current == null) {
    chatRef.current = createChat();
  }

  const shouldRecreateChat =
    providedChat != null && providedChat !== chatRef.current ||
    providedId != null && chatRef.current.id !== providedId;
  if (shouldRecreateChat) {
    chatRef.current = createChat();
  }
  const chat = chatRef.current;

  // Depending on the chat instance itself re-subscribes whenever the chat is
  // replaced, whether that happened through a new `id` or a new `chat` option.
  const subscribeToMessages = useCallback(
    (update) => chat["~registerMessagesCallback"](update, throttleWaitMs),
    [throttleWaitMs, chat]
  );
  const messages = useSyncExternalStore(
    subscribeToMessages,
    () => chatRef.current.messages,
    () => chatRef.current.messages
  );
  const status = useSyncExternalStore(
    chat["~registerStatusCallback"],
    () => chatRef.current.status,
    () => chatRef.current.status
  );
  const error = useSyncExternalStore(
    chat["~registerErrorCallback"],
    () => chatRef.current.error,
    () => chatRef.current.error
  );

  // Accepts either a new message list or an updater receiving the current list.
  const setMessages = useCallback(
    (messagesParam) => {
      if (typeof messagesParam === "function") {
        messagesParam = messagesParam(chatRef.current.messages);
      }
      chatRef.current.messages = messagesParam;
    },
    [chatRef]
  );

  useEffect(() => {
    if (resume) {
      chatRef.current.resumeStream();
    }
  }, [resume, chatRef]);

  return {
    id: chat.id,
    messages,
    setMessages,
    sendMessage: chat.sendMessage,
    regenerate: chat.regenerate,
    clearError: chat.clearError,
    stop: chat.stop,
    error,
    resumeStream: chat.resumeStream,
    status,
    /**
     * @deprecated Use `addToolOutput` instead.
     */
    addToolResult: chat.addToolOutput,
    addToolOutput: chat.addToolOutput
  };
}
// src/use-completion.ts
import {
callCompletionApi
} from "ai";
import { useCallback as useCallback2, useEffect as useEffect2, useId, useRef as useRef2, useState } from "react";
import useSWR from "swr";
/**
 * React hook for streaming text completions.
 *
 * The completion text and the loading flag are kept in SWR entries keyed by
 * `[api, completionId]` and `[completionId, "loading"]`; error, abort controller
 * and input are local React state. Requests are issued through `callCompletionApi`.
 *
 * @param api                   Endpoint, defaults to "/api/completion".
 * @param id                    Optional id; a `useId()` value is used when falsy.
 * @param initialCompletion     Fallback completion text, defaults to "".
 * @param initialInput          Initial input value, defaults to "".
 * @param credentials           Forwarded to `callCompletionApi`.
 * @param headers               Base headers, merged under per-call headers.
 * @param body                  Base body fields, merged under per-call body fields.
 * @param streamProtocol        Forwarded to `callCompletionApi`, defaults to "data".
 * @param fetch                 Optional fetch implementation, forwarded as `fetch`.
 * @param onFinish              Forwarded to `callCompletionApi`.
 * @param onError               Forwarded to `callCompletionApi`.
 * @param experimental_throttle Optional wait (ms) for throttling streamed updates.
 * @returns `{ completion, complete, error, setCompletion, stop, input, setInput,
 *          handleInputChange, handleSubmit, isLoading }`.
 */
function useCompletion({
  api = "/api/completion",
  id,
  initialCompletion = "",
  initialInput = "",
  credentials,
  headers,
  body,
  streamProtocol = "data",
  fetch: fetch2,
  onFinish,
  onError,
  experimental_throttle: throttleWaitMs
} = {}) {
  const generatedId = useId();
  const completionId = id || generatedId;

  const completionStore = useSWR([api, completionId], null, {
    fallbackData: initialCompletion
  });
  const completion = completionStore.data;
  const mutate = completionStore.mutate;

  const loadingStore = useSWR([completionId, "loading"], null);
  const storedLoading = loadingStore.data;
  const isLoading = storedLoading === void 0 ? false : storedLoading;
  const mutateLoading = loadingStore.mutate;

  const [error, setError] = useState(void 0);
  const [abortController, setAbortController] = useState(null);

  // The latest credentials/headers/body are mirrored into a ref so that
  // `triggerRequest` can read them without listing them as dependencies.
  const extraMetadataRef = useRef2({ credentials, headers, body });
  useEffect2(() => {
    extraMetadataRef.current = { credentials, headers, body };
  }, [credentials, headers, body]);

  const triggerRequest = useCallback2(
    async (prompt, options) => {
      const metadata = extraMetadataRef.current;
      const callHeaders = options == null ? void 0 : options.headers;
      const callBody = options == null ? void 0 : options.body;
      // throttle streamed ui updates:
      const publishCompletion = throttle(
        (completion2) => mutate(completion2, false),
        throttleWaitMs
      );
      return callCompletionApi({
        api,
        prompt,
        credentials: metadata.credentials,
        headers: { ...metadata.headers, ...callHeaders },
        body: { ...metadata.body, ...callBody },
        streamProtocol,
        fetch: fetch2,
        setCompletion: publishCompletion,
        setLoading: mutateLoading,
        setError,
        setAbortController,
        onFinish,
        onError
      });
    },
    [
      mutate,
      mutateLoading,
      api,
      extraMetadataRef,
      setAbortController,
      onFinish,
      onError,
      setError,
      streamProtocol,
      fetch2,
      throttleWaitMs
    ]
  );

  // Aborts the request tracked in state, if any, and forgets its controller.
  const stop = useCallback2(() => {
    if (!abortController) {
      return;
    }
    abortController.abort();
    setAbortController(null);
  }, [abortController]);

  // Overwrites the stored completion without triggering SWR revalidation.
  const setCompletion = useCallback2(
    (completion2) => {
      mutate(completion2, false);
    },
    [mutate]
  );

  const complete = useCallback2(
    async (prompt, options) => triggerRequest(prompt, options),
    [triggerRequest]
  );

  const [input, setInput] = useState(initialInput);

  // Form submit handler: prevents the default action when the event supports it
  // and requests a completion for the current input, if it is non-empty.
  const handleSubmit = useCallback2(
    (event) => {
      const preventDefault = event == null ? void 0 : event.preventDefault;
      if (preventDefault != null) {
        preventDefault.call(event);
      }
      if (!input) {
        return void 0;
      }
      return complete(input);
    },
    [input, complete]
  );

  const handleInputChange = useCallback2(
    (e) => {
      setInput(e.target.value);
    },
    [setInput]
  );

  return {
    completion,
    complete,
    error,
    setCompletion,
    stop,
    input,
    setInput,
    handleInputChange,
    handleSubmit,
    isLoading
  };
}
// src/use-object.ts
import {
isAbortError,
safeValidateTypes
} from "@ai-sdk/provider-utils";
import {
asSchema,
isDeepEqualData,
parsePartialJson
} from "ai";
import { useCallback as useCallback3, useId as useId2, useRef as useRef3, useState as useState2 } from "react";
import useSWR2 from "swr";
/**
 * Returns whatever the global `fetch` binding refers to at call time
 * (the lookup happens on each call, not once at module load).
 */
var getOriginalFetch = () => {
  return fetch;
};
/**
 * React hook that POSTs an input to `api` and streams back a JSON object.
 *
 * The response body is decoded as text, accumulated, and re-parsed as partial
 * JSON on every chunk; each new intermediate object is published through SWR
 * under the key `[api, completionId]`. When the stream closes, the last object is
 * validated against `schema` and reported to `onFinish`.
 *
 * @param api          Endpoint that receives the POST request.
 * @param id           Optional id; a `useId()` value is used when null/undefined.
 * @param schema       Schema the final object is validated against before `onFinish`.
 * @param initialValue Fallback value for `object`.
 * @param fetch        Optional fetch implementation; defaults to the global `fetch`.
 * @param onError      Called with an `Error` for every failure except aborts.
 * @param onFinish     Called with `{ object, error }` once the stream has closed.
 * @param headers      Extra request headers, merged over the JSON content type.
 * @param credentials  Forwarded to `fetch`.
 * @returns `{ submit, object, error, isLoading, stop, clear }`.
 */
function useObject({
  api,
  id,
  schema,
  initialValue,
  fetch: fetch2,
  onError,
  onFinish,
  headers,
  credentials
}) {
  const hookId = useId2();
  const completionId = id != null ? id : hookId;
  const { data, mutate } = useSWR2(
    [api, completionId],
    null,
    { fallbackData: initialValue }
  );
  const [error, setError] = useState2(void 0);
  const [isLoading, setIsLoading] = useState2(false);
  const abortControllerRef = useRef3(null);

  // Aborts the request in flight, if any. Errors raised by `abort()` are
  // deliberately ignored: stopping is best-effort and must always reset state.
  const stop = useCallback3(() => {
    try {
      if (abortControllerRef.current != null) {
        abortControllerRef.current.abort();
      }
    } catch (ignored) {
    } finally {
      setIsLoading(false);
      abortControllerRef.current = null;
    }
  }, []);

  // Resets error, loading flag and the published object.
  const clearObject = () => {
    setError(void 0);
    setIsLoading(false);
    mutate(void 0);
  };

  // Sends `input` as JSON and streams the partial object into SWR.
  const submit = async (input) => {
    try {
      clearObject();
      setIsLoading(true);
      const abortController = new AbortController();
      abortControllerRef.current = abortController;
      const actualFetch = fetch2 != null ? fetch2 : getOriginalFetch();
      const response = await actualFetch(api, {
        method: "POST",
        headers: {
          "Content-Type": "application/json",
          ...headers
        },
        credentials,
        signal: abortController.signal,
        body: JSON.stringify(input)
      });
      if (!response.ok) {
        // `text()` resolves to a string, never null: an empty error body must fall
        // back to the generic message instead of producing an empty Error message.
        const errorText = await response.text();
        throw new Error(errorText ? errorText : "Failed to fetch the response.");
      }
      if (response.body == null) {
        throw new Error("The response body is empty.");
      }
      let accumulatedText = "";
      let latestObject = void 0;
      await response.body.pipeThrough(new TextDecoderStream()).pipeTo(
        new WritableStream({
          async write(chunk) {
            accumulatedText += chunk;
            const { value } = await parsePartialJson(accumulatedText);
            const currentObject = value;
            // Only publish when the parsed object actually changed.
            if (!isDeepEqualData(latestObject, currentObject)) {
              latestObject = currentObject;
              mutate(currentObject);
            }
          },
          async close() {
            setIsLoading(false);
            abortControllerRef.current = null;
            if (onFinish != null) {
              const validationResult = await safeValidateTypes({
                value: latestObject,
                schema: asSchema(schema)
              });
              onFinish(
                validationResult.success ? { object: validationResult.value, error: void 0 } : { object: void 0, error: validationResult.error }
              );
            }
          }
        })
      );
    } catch (caught) {
      // Aborts are triggered by `stop()`, which already reset the state.
      if (isAbortError(caught)) {
        return;
      }
      // Normalise once so that `onError` and the `error` state always agree,
      // including when something other than an Error instance was thrown.
      const normalizedError = caught instanceof Error ? caught : new Error(String(caught));
      if (onError) {
        onError(normalizedError);
      }
      setIsLoading(false);
      setError(normalizedError);
    }
  };

  // Stops any request in flight and resets the hook's state.
  const clear = () => {
    stop();
    clearObject();
  };

  return {
    submit,
    object: data,
    error,
    isLoading,
    stop,
    clear
  };
}
// `experimental_useObject` is an alias: it refers to the same function as `useObject`.
var experimental_useObject = useObject;
// Public surface of this bundle. `useObject` is exported only under its
// `experimental_` alias; `ReactChatState` and the helper functions stay internal.
export {
Chat,
experimental_useObject,
useChat,
useCompletion
};
//# sourceMappingURL=index.mjs.map