@ai-sdk/rsc
Version:
[React Server Components](https://react.dev/reference/rsc/server-components) for the [AI SDK](https://ai-sdk.dev/docs):
739 lines (726 loc) • 20.4 kB
JavaScript
// src/ai-state.tsx
import * as jsondiffpatch from "jsondiffpatch";
import { AsyncLocalStorage } from "async_hooks";
// src/util/create-resolvable-promise.ts
/**
 * Creates a promise together with the functions that settle it, so the
 * promise can be resolved or rejected from outside its executor.
 *
 * @returns {{ promise: Promise<any>, resolve: (value?: any) => void, reject: (reason?: any) => void }}
 */
function createResolvablePromise() {
  const handle = { promise: void 0, resolve: void 0, reject: void 0 };
  handle.promise = new Promise((settle, fail) => {
    handle.resolve = settle;
    handle.reject = fail;
  });
  return handle;
}
// src/util/is-function.ts
/**
 * Reports whether `value` is callable.
 *
 * @param {unknown} value
 * @returns {boolean}
 */
function isFunction(value) {
  return typeof value === "function";
}
// src/ai-state.tsx
// Async-context storage for the AI state store of the AI Action being run.
const asyncAIStateStorage = new AsyncLocalStorage();
/**
 * Returns the AI state store bound to the current async context.
 *
 * @param {string} message - Error message used when no store is active.
 * @returns {object} The active store.
 * @throws {Error} When called outside `asyncAIStateStorage.run(...)`.
 */
function getAIStateStoreOrThrow(message) {
  const activeStore = asyncAIStateStorage.getStore();
  if (activeStore) {
    return activeStore;
  }
  throw new Error(message);
}
/**
 * Runs `fn` with a fresh AI state store bound to the current async context.
 *
 * The store keeps the caller's `state` as `originalState` and a deep copy of
 * it (made with a JSON round-trip) as `currentState`, so mutations made while
 * `fn` runs can later be diffed against the original.
 *
 * @param {{ state: any, options: object }} context - Initial AI state and the
 *   options object exposed to the AI state accessors (e.g. `onSetAIState`).
 * @param {() => any} fn - Callback executed inside the store's context.
 * @returns {any} Whatever `fn` returns.
 */
function withAIState({ state, options }, fn) {
  // `JSON.stringify` returns `undefined` (not a string) for `undefined`,
  // functions and symbols, and `JSON.parse(undefined)` throws a SyntaxError.
  // That would crash actions invoked while the AI state is still unset, so
  // such states are carried over as `undefined` instead.
  const serialized = JSON.stringify(state);
  return asyncAIStateStorage.run(
    {
      // deep clone object
      currentState: serialized === void 0 ? void 0 : JSON.parse(serialized),
      originalState: state,
      sealed: false,
      options
    },
    fn
  );
}
/**
 * Returns the AI state delta promise of the active store.
 *
 * It is only assigned when `getMutableAIState` is called, so it may be
 * `undefined`.
 *
 * @returns {Promise<any> | undefined}
 * @throws {Error} "Internal error occurred." when no store is active.
 */
function getAIStateDeltaPromise() {
  const { mutationDeltaPromise } = getAIStateStoreOrThrow("Internal error occurred.");
  return mutationDeltaPromise;
}
/**
 * Marks the active AI state store as sealed.
 *
 * After this, `getMutableAIState` throws for that store.
 *
 * @throws {Error} "Internal error occurred." when no store is active.
 */
function sealMutableAIState() {
  getAIStateStoreOrThrow("Internal error occurred.").sealed = true;
}
/**
 * Reads the AI state of the AI Action currently executing.
 *
 * @param {...any} args - Optional single key. When given, only that field of
 *   the (object) AI state is returned.
 * @returns {any} The whole AI state, or the requested field.
 * @throws {Error} When called outside an AI Action, or when a key is given
 *   but the AI state is not a non-null object.
 */
function getAIState(...args) {
  const store = getAIStateStoreOrThrow(
    "`getAIState` must be called within an AI Action."
  );
  const state = store.currentState;
  if (args.length === 0) {
    return state;
  }
  const key = args[0];
  // `typeof null === "object"`, so null must be rejected explicitly;
  // indexing it would otherwise throw an opaque TypeError.
  if (typeof state !== "object" || state === null) {
    throw new Error(
      `You can't get the "${String(key)}" field from the AI state because it's not an object.`
    );
  }
  return state[key];
}
/**
 * Returns a handle for reading and mutating the AI state of the AI Action
 * currently executing.
 *
 * The first call for a store creates that store's mutation-delta promise.
 * `done()` resolves it with `jsondiffpatch.diff(originalState, currentState)`.
 *
 * @param {...any} args - Optional single key. When given, the handle reads and
 *   writes only that field of the (object) AI state.
 * @returns {{ get: () => any, update: (newState: any) => void, done: (...newState: [] | [any]) => void }}
 *   `update`/`done` accept either a value or an updater function
 *   `(previous) => next`.
 * @throws {Error} When called outside an AI Action, or after the action's
 *   store has been sealed. The handle's methods throw when a key is used but
 *   the AI state is not a non-null object.
 */
function getMutableAIState(...args) {
  const store = getAIStateStoreOrThrow(
    "`getMutableAIState` must be called within an AI Action."
  );
  if (store.sealed) {
    throw new Error(
      "`getMutableAIState` must be called before returning from an AI Action. Please move it to the top level of the Action's function body."
    );
  }
  if (!store.mutationDeltaPromise) {
    const { promise, resolve } = createResolvablePromise();
    store.mutationDeltaPromise = promise;
    store.mutationDeltaResolve = resolve;
  }
  const hasKey = args.length > 0;
  const key = args[0];
  // `typeof null === "object"`, so null must be excluded explicitly before
  // the state is indexed by `key`.
  function isObjectState() {
    return typeof store.currentState === "object" && store.currentState !== null;
  }
  // Apply `newState` (a value or an updater function) and notify `onSetAIState`.
  function doUpdate(newState, done) {
    if (hasKey && !isObjectState()) {
      throw new Error(
        `You can't modify the "${String(key)}" field of the AI state because it's not an object.`
      );
    }
    if (isFunction(newState)) {
      if (hasKey) {
        store.currentState[key] = newState(store.currentState[key]);
      } else {
        store.currentState = newState(store.currentState);
      }
    } else if (hasKey) {
      store.currentState[key] = newState;
    } else {
      store.currentState = newState;
    }
    const { options } = store;
    const { onSetAIState } = options;
    if (onSetAIState != null) {
      onSetAIState.call(options, {
        key: hasKey ? key : void 0,
        state: store.currentState,
        done
      });
    }
  }
  const mutableState = {
    get: () => {
      if (!hasKey) {
        return store.currentState;
      }
      if (!isObjectState()) {
        throw new Error(
          `You can't get the "${String(key)}" field from the AI state because it's not an object.`
        );
      }
      return store.currentState[key];
    },
    update: function update(newAIState) {
      doUpdate(newAIState, false);
    },
    done: function done(...doneArgs) {
      // NOTE: `onSetAIState` only fires (with `done: true`) when a final
      // state is passed; a bare `done()` only resolves the delta promise.
      if (doneArgs.length > 0) {
        doUpdate(doneArgs[0], true);
      }
      const delta = jsondiffpatch.diff(store.originalState, store.currentState);
      store.mutationDeltaResolve(delta);
    }
  };
  return mutableState;
}
// src/provider.tsx
import * as React from "react";
import { InternalAIProvider } from "./rsc-shared.mjs";
import { jsx } from "react/jsx-runtime";
/**
 * Server action to which every wrapped AI action is bound (see `wrapAction`).
 *
 * Runs `action(...args)` inside a fresh AI state context seeded with `state`.
 * Once the action's promise resolves, the context is sealed, so later
 * `getMutableAIState` calls for it throw.
 *
 * Returns the tuple `[aiStateDeltaPromise, result]`. The delta promise is the
 * store's `mutationDeltaPromise`:
 *   - it is `undefined` when the action never called `getMutableAIState`;
 *   - otherwise it stays pending until `.done()` is called on a mutable AI
 *     state handle.
 *
 * @param {{ action: Function, options: object }} context - Bound by `wrapAction`.
 * @param {any} state - The AI state supplied by the caller.
 * @param {...any} args - Arguments forwarded to `action`.
 * @returns {Promise<[Promise<any> | undefined, any]>}
 */
async function innerAction({
  action,
  options
}, state, ...args) {
  "use server";
  // NOTE(review): the directive above must stay the first statement of the
  // body; presumably the RSC tooling keys off it, so do not move it.
  return await withAIState(
    {
      state,
      options
    },
    async () => {
      const result = await action(...args);
      // Reached only after `action` resolved; seals the store.
      sealMutableAIState();
      return [getAIStateDeltaPromise(), result];
    }
  );
}
/**
 * Binds `innerAction` to one action and its options, producing the function
 * handed to the AI provider.
 *
 * @param {Function} action - The user-supplied AI action.
 * @param {object} options - Options stored in the action's AI state store.
 * @returns {Function} `innerAction` partially applied with `{ action, options }`
 *   and `this` bound to `null`.
 */
function wrapAction(action, options) {
  const boundContext = { action, options };
  return innerAction.bind(null, boundContext);
}
/**
 * Creates the `<AI>` provider component for a set of server actions.
 *
 * Every action is wrapped with `wrapAction` so that it runs in its own AI
 * state context. When `onGetUIState` is given, it is wrapped too and called
 * on each render to derive the UI state (and an AI state patch) from the AI
 * state.
 *
 * @param {object} config
 * @param {Record<string, Function>} config.actions - Server actions to expose.
 * @param {any} config.initialAIState - Default AI state.
 * @param {any} config.initialUIState - Default UI state.
 * @param {Function} [config.onSetAIState] - Stored in each action's AI state
 *   options.
 * @param {Function} [config.onGetUIState] - Optional AI-state-to-UI-state mapper.
 * @returns {Function} An async component rendering `InternalAIProvider`.
 *   Its props `initialUIState`/`initialAIState` override the defaults when
 *   non-nullish.
 */
function createAI({
  actions,
  initialAIState,
  initialUIState,
  onSetAIState,
  onGetUIState
}) {
  const wrappedActions = {};
  for (const actionName in actions) {
    wrappedActions[actionName] = wrapAction(actions[actionName], { onSetAIState });
  }
  const wrappedSyncUIState = onGetUIState ? wrapAction(onGetUIState, {}) : void 0;

  const AI = async (props) => {
    // Presumably only the client build of React exposes `useState`, so its
    // presence means this is not running as a Server Component.
    if ("useState" in React) {
      throw new Error(
        "This component can only be used inside Server Components."
      );
    }
    const propsUIState = props.initialUIState;
    const propsAIState = props.initialAIState;
    let uiState = propsUIState != null ? propsUIState : initialUIState;
    const aiState = propsAIState != null ? propsAIState : initialAIState;
    let aiStateDelta;
    if (wrappedSyncUIState) {
      const [delta, syncedUIState] = await wrappedSyncUIState(aiState);
      // An `undefined` UI state from `onGetUIState` keeps the existing one.
      if (syncedUIState !== void 0) {
        aiStateDelta = delta;
        uiState = syncedUIState;
      }
    }
    return /* @__PURE__ */ jsx(InternalAIProvider, {
      wrappedActions,
      wrappedSyncUIState,
      initialUIState: uiState,
      initialAIState: aiState,
      initialAIStatePatch: aiStateDelta,
      children: props.children
    });
  };
  return AI;
}
// src/stream-ui/stream-ui.tsx
import {
safeParseJSON
} from "@ai-sdk/provider-utils";
import {
InvalidToolInputError,
NoSuchToolError
} from "ai";
import {
standardizePrompt,
prepareToolsAndToolChoice,
prepareRetries,
prepareCallSettings,
convertToLanguageModelPrompt
} from "ai/internal";
// src/util/is-async-generator.ts
/**
 * Reports whether `value` is an async iterator, such as an async generator
 * object.
 *
 * Being async-iterable is not enough: the value must also expose a callable
 * `next`, because the streaming renderer drives it by calling `next()`
 * directly. A plain async iterable without `next` is therefore not treated
 * as a generator.
 *
 * @param {unknown} value
 * @returns {boolean}
 */
function isAsyncGenerator(value) {
  return (
    value != null &&
    typeof value === "object" &&
    Symbol.asyncIterator in value &&
    typeof value.next === "function"
  );
}
// src/util/is-generator.ts
/**
 * Reports whether `value` is a synchronous iterator, such as a generator
 * object.
 *
 * Being iterable is not enough: arrays, Sets and Maps are iterable but have
 * no `next` method. A renderer may legitimately return an array of nodes, and
 * such values must be rendered as ordinary nodes rather than driven via
 * `next()`.
 *
 * @param {unknown} value
 * @returns {boolean}
 */
function isGenerator(value) {
  return (
    value != null &&
    typeof value === "object" &&
    Symbol.iterator in value &&
    typeof value.next === "function"
  );
}
// src/util/constants.ts
// Delay (15 000 ms) after which a still-open streamable logs a development warning.
const HANGING_STREAM_WARNING_TIME_MS = 15000;
// src/streamable-ui/create-suspended-chunk.tsx
import { Suspense } from "react";
import { Fragment, jsx as jsx2, jsxs } from "react/jsx-runtime";
// Recursive async component that renders one step of a streamable UI.
// Props use single letters (presumably to keep serialized props small —
// TODO confirm):
//   c - the node currently shown for this step;
//   n - a promise for the next chunk `{ value, done, append?, next? }`, as
//       produced by `createStreamableUI`.
// NOTE(review): the function is taken out of a one-element array, which
// leaves it anonymous (`R.name === ""`). Do not simplify this without
// checking why the original source needs it.
var R = [
  async ({
    c: current,
    n: next
  }) => {
    const chunk = await next;
    // Final chunk: render its value in place of this step.
    if (chunk.done) {
      return chunk.value;
    }
    // Appended chunk: keep the current node and render the rest after it.
    if (chunk.append) {
      return /* @__PURE__ */ jsxs(Fragment, { children: [
        current,
        /* @__PURE__ */ jsx2(Suspense, { fallback: chunk.value, children: /* @__PURE__ */ jsx2(R, { c: chunk.value, n: chunk.next }) })
      ] });
    }
    // Replacing chunk: show its value while the following chunk is awaited.
    return /* @__PURE__ */ jsx2(Suspense, { fallback: chunk.value, children: /* @__PURE__ */ jsx2(R, { c: chunk.value, n: chunk.next }) });
  }
][0];
/**
 * Creates the root row of a streamable UI.
 *
 * The row is a `Suspense` boundary that shows `initialValue` while `R` waits
 * for the first chunk.
 *
 * @param {any} initialValue - Node shown until the first chunk arrives.
 * @returns {{ row: any, resolve: (chunk: any) => void, reject: (reason?: any) => void }}
 *   The row plus the settle functions of the first chunk's promise.
 */
function createSuspendedChunk(initialValue) {
  const firstChunk = createResolvablePromise();
  const row = /* @__PURE__ */ jsx2(Suspense, {
    fallback: initialValue,
    children: /* @__PURE__ */ jsx2(R, { c: initialValue, n: firstChunk.promise })
  });
  return {
    row,
    resolve: firstChunk.resolve,
    reject: firstChunk.reject
  };
}
// src/streamable-ui/create-streamable-ui.tsx
/**
 * Creates a streamable UI.
 *
 * The result is a React node (`value`) whose content can still be replaced
 * (`update`), extended (`append`), finalized (`done`) or failed (`error`)
 * after it has been returned.
 *
 * Chunk protocol:
 *   - each update resolves the pending chunk promise with
 *     `{ value, done: false, append?, next }`, where `next` is the promise
 *     for the following chunk;
 *   - `done` resolves it with `{ value, done: true }`;
 *   - `error` rejects it.
 *
 * In development, a warning is logged if the stream is neither updated nor
 * closed within `HANGING_STREAM_WARNING_TIME_MS`.
 *
 * @param {any} [initialValue] - Node rendered until the first update.
 * @returns {object} The streamable UI handle. Its mutators return the handle
 *   and throw once the stream has been closed by `done` or `error`.
 */
function createStreamableUI(initialValue) {
  const firstChunk = createSuspendedChunk(initialValue);
  let resolveChunk = firstChunk.resolve;
  let rejectChunk = firstChunk.reject;
  let latestValue = initialValue;
  let isClosed = false;
  let hangTimer;

  const ensureOpen = (method) => {
    if (isClosed) {
      throw new Error(`${method}: UI stream is already closed.`);
    }
  };
  const stopHangTimer = () => {
    if (hangTimer) {
      clearTimeout(hangTimer);
    }
  };
  // Development only: (re)start the timer that reports a stream left open.
  const armHangTimer = () => {
    if (process.env.NODE_ENV !== "development") {
      return;
    }
    stopHangTimer();
    hangTimer = setTimeout(() => {
      console.warn(
        "The streamable UI has been slow to update. This may be a bug or a performance issue or you forgot to call `.done()`."
      );
    }, HANGING_STREAM_WARNING_TIME_MS);
  };
  // Settle the pending chunk with `payload` and start waiting on a fresh one.
  const pushChunk = (payload) => {
    const successor = createResolvablePromise();
    resolveChunk({ ...payload, next: successor.promise });
    resolveChunk = successor.resolve;
    rejectChunk = successor.reject;
    armHangTimer();
  };

  armHangTimer();

  const streamable = {
    value: firstChunk.row,
    update(value) {
      ensureOpen(".update()");
      if (value === latestValue) {
        // Nothing new to render; only push the hang warning back.
        armHangTimer();
      } else {
        latestValue = value;
        pushChunk({ value, done: false });
      }
      return streamable;
    },
    append(value) {
      ensureOpen(".append()");
      latestValue = value;
      pushChunk({ value, done: false, append: true });
      return streamable;
    },
    error(error) {
      ensureOpen(".error()");
      stopHangTimer();
      isClosed = true;
      rejectChunk(error);
      return streamable;
    },
    done(...args) {
      ensureOpen(".done()");
      stopHangTimer();
      isClosed = true;
      resolveChunk({ value: args.length ? args[0] : latestValue, done: true });
      return streamable;
    }
  };
  return streamable;
}
// src/stream-ui/stream-ui.tsx
// Default `text` renderer for `streamUI`: renders the accumulated text as-is.
function defaultTextRenderer(props) {
  return props.content;
}
/**
 * Streams a model response and renders it as React Server Component UI.
 *
 * Rendering:
 *   - text deltas are rendered with `text` (default: the accumulated text);
 *   - a tool call is rendered with the matching tool's `generate` function,
 *     and that render closes the UI.
 *
 * The returned object spreads the provider's `doStream` result, replaces
 * `stream` with one branch of its tee, and adds `value`, the streamable UI
 * node. Errors raised while consuming the stream — including errors thrown
 * by a renderer or a tool's `generate` — are reported through the
 * streamable UI's `.error()`.
 *
 * @param {object} options - Model, prompt, tools, renderers and call
 *   settings. Unrecognised keys are passed to `prepareCallSettings`.
 * @returns {Promise<object>} The `doStream` result plus `stream` and `value`.
 *   Rejects up front for a string `model`, for `functions`/`provider` keys,
 *   or for tools that declare a `render` property.
 */
async function streamUI({
  model,
  tools,
  toolChoice,
  system,
  prompt,
  messages,
  maxRetries,
  abortSignal,
  headers,
  initial,
  text,
  providerOptions,
  onFinish,
  ...settings
}) {
  if (typeof model === "string") {
    throw new Error(
      "`model` cannot be a string in `streamUI`. Use the actual model instance instead."
    );
  }
  if ("functions" in settings) {
    throw new Error(
      "`functions` is not supported in `streamUI`, use `tools` instead."
    );
  }
  if ("provider" in settings) {
    throw new Error(
      "`provider` is no longer needed in `streamUI`. Use `model` instead."
    );
  }
  if (tools) {
    for (const [name, tool] of Object.entries(tools)) {
      if ("render" in tool) {
        throw new Error(
          "Tool definition in `streamUI` should not have `render` property. Use `generate` instead. Found in tool: " + name
        );
      }
    }
  }
  const ui = createStreamableUI(initial);
  const textRender = text || defaultTextRenderer;
  // Promise chain that settles once every render started so far has finished.
  let finished;
  let finishEvent = null;
  // First error thrown by any renderer, boxed so that even a falsy thrown
  // value is detected. Renders are started without `await`, so errors are
  // collected here instead of escaping as unhandled promise rejections.
  let renderFailure;
  async function render({
    args,
    renderer,
    streamableUI,
    isLastCall = false
  }) {
    if (!renderer)
      return;
    const renderFinished = createResolvablePromise();
    finished = finished ? finished.then(() => renderFinished.promise) : renderFinished.promise;
    try {
      const rendererResult = renderer(...args);
      if (isAsyncGenerator(rendererResult) || isGenerator(rendererResult)) {
        while (true) {
          const { done, value } = await rendererResult.next();
          const node = await value;
          if (isLastCall && done) {
            streamableUI.done(node);
          } else {
            streamableUI.update(node);
          }
          if (done)
            break;
        }
      } else {
        const node = await rendererResult;
        if (isLastCall) {
          streamableUI.done(node);
        } else {
          streamableUI.update(node);
        }
      }
    } catch (error) {
      if (renderFailure === void 0) {
        renderFailure = { error };
      }
    } finally {
      // Always settle, so `await finished` below cannot hang on a failed render.
      renderFinished.resolve(void 0);
    }
  }
  const { retry } = prepareRetries({ maxRetries, abortSignal });
  const validatedPrompt = await standardizePrompt({
    system,
    prompt,
    messages
  });
  const result = await retry(
    async () => model.doStream({
      ...prepareCallSettings(settings),
      ...prepareToolsAndToolChoice({
        tools,
        toolChoice,
        activeTools: void 0
      }),
      prompt: await convertToLanguageModelPrompt({
        prompt: validatedPrompt,
        supportedUrls: await model.supportedUrls,
        download: void 0
      }),
      providerOptions,
      abortSignal,
      headers,
      includeRawChunks: false
    })
  );
  // One branch is returned to the caller; the other drives the UI below.
  const [stream, forkedStream] = result.stream.tee();
  (async () => {
    try {
      let content = "";
      let hasToolCall = false;
      let warnings;
      const reader = forkedStream.getReader();
      while (true) {
        const { done, value } = await reader.read();
        if (done)
          break;
        switch (value.type) {
          case "stream-start": {
            warnings = value.warnings;
            break;
          }
          case "text-delta": {
            content += value.delta;
            render({
              renderer: textRender,
              args: [{ content, done: false, delta: value.delta }],
              streamableUI: ui
            });
            break;
          }
          case "tool-input-start":
          case "tool-input-delta": {
            hasToolCall = true;
            break;
          }
          case "tool-call": {
            const toolName = value.toolName;
            if (!tools) {
              throw new NoSuchToolError({ toolName });
            }
            const tool = tools[toolName];
            if (!tool) {
              throw new NoSuchToolError({
                toolName,
                availableTools: Object.keys(tools)
              });
            }
            hasToolCall = true;
            const parseResult = await safeParseJSON({
              text: value.input,
              schema: tool.inputSchema
            });
            if (parseResult.success === false) {
              throw new InvalidToolInputError({
                toolName,
                toolInput: value.input,
                cause: parseResult.error
              });
            }
            render({
              renderer: tool.generate,
              args: [
                parseResult.value,
                {
                  toolName,
                  toolCallId: value.toolCallId
                }
              ],
              streamableUI: ui,
              isLastCall: true
            });
            break;
          }
          case "error": {
            throw value.error;
          }
          case "finish": {
            finishEvent = {
              finishReason: value.finishReason,
              usage: value.usage,
              warnings,
              response: result.response
            };
            break;
          }
        }
      }
      // Without a tool call, the final text render closes the UI.
      if (!hasToolCall) {
        render({
          renderer: textRender,
          args: [{ content, done: true }],
          streamableUI: ui,
          isLastCall: true
        });
      }
      await finished;
      if (renderFailure !== void 0) {
        throw renderFailure.error;
      }
      if (finishEvent && onFinish) {
        await onFinish({
          ...finishEvent,
          value: ui.value
        });
      }
    } catch (error) {
      // NOTE(review): if a final render already closed the UI, `ui.error`
      // itself throws here ("UI stream is already closed").
      ui.error(error);
    }
  })();
  return {
    ...result,
    stream,
    value: ui.value
  };
}
// src/streamable-value/streamable-value.ts
// Globally registered symbol (`Symbol.for`) that tags the initial chunk of a streamable value.
const STREAMABLE_VALUE_TYPE = Symbol.for("ui.streamable.value");
// src/streamable-value/create-streamable-value.ts
// Module-private symbol used as the key of a streamable value's lock setter.
const STREAMABLE_VALUE_INTERNAL_LOCK = Symbol("streamable.value.lock");
/**
 * Creates a streamable value.
 *
 * Plain inputs become the initial value. A `ReadableStream`, or any object
 * with a `getReader()` method and a boolean `locked` property, is piped into
 * the value instead:
 *   - string chunks are appended, other chunks replace the value;
 *   - the end of the stream calls `done()`;
 *   - a read error calls `error()`.
 * While piping, the value is internally locked, so external mutators throw.
 *
 * @param {any} [initialValue] - Initial value or source stream.
 * @returns {object} The streamable value handle.
 */
function createStreamableValue(initialValue) {
  const looksLikeReadableStream =
    initialValue instanceof ReadableStream ||
    (typeof initialValue === "object" &&
      initialValue !== null &&
      "getReader" in initialValue &&
      typeof initialValue.getReader === "function" &&
      "locked" in initialValue &&
      typeof initialValue.locked === "boolean");
  if (!looksLikeReadableStream) {
    return createStreamableValueImpl(initialValue);
  }
  const streamableValue = createStreamableValueImpl();
  const setLocked = (flag) => {
    streamableValue[STREAMABLE_VALUE_INTERNAL_LOCK] = flag;
  };
  setLocked(true);
  // Unlock only around our own writes so callers cannot interleave theirs.
  const pump = async () => {
    try {
      const reader = initialValue.getReader();
      for (;;) {
        const { value, done } = await reader.read();
        if (done) {
          break;
        }
        setLocked(false);
        if (typeof value === "string") {
          streamableValue.append(value);
        } else {
          streamableValue.update(value);
        }
        setLocked(true);
      }
      setLocked(false);
      streamableValue.done();
    } catch (e) {
      setLocked(false);
      streamableValue.error(e);
    }
  };
  // Fire-and-forget: all failures are routed into `streamableValue.error`.
  void pump();
  return streamableValue;
}
/**
 * Core implementation of a streamable value.
 *
 * Chunk protocol:
 *   - `value` returns the initial chunk `{ curr | error, next?, type }`,
 *     tagged with `STREAMABLE_VALUE_TYPE`;
 *   - every later write resolves the previous chunk's `next` promise with
 *     `{ curr }`, `{ diff: [0, suffix] }` (when a string value only grew) or
 *     `{ error }`, plus `next` while the stream is still open;
 *   - `done()` without arguments resolves with `{}`.
 *
 * @param {any} [initialValue] - Value reported before any write.
 * @returns {object} The streamable handle. Its mutators return the handle
 *   and throw when the stream is closed or internally locked.
 */
function createStreamableValueImpl(initialValue) {
  let isClosed = false;
  let isLocked = false;
  let pending = createResolvablePromise();
  let latestValue = initialValue;
  let latestError;
  let nextPromise = pending.promise;
  let latestPatch;
  let hangTimer;

  function ensureWritable(method) {
    if (isClosed) {
      throw new Error(method + ": Value stream is already closed.");
    }
    if (isLocked) {
      throw new Error(
        method + ": Value stream is locked and cannot be updated."
      );
    }
  }
  function clearHangTimer() {
    if (hangTimer) {
      clearTimeout(hangTimer);
    }
  }
  // Development only: (re)start the timer that reports a value left open.
  function armHangTimer() {
    if (process.env.NODE_ENV !== "development") {
      return;
    }
    clearHangTimer();
    hangTimer = setTimeout(() => {
      console.warn(
        "The streamable value has been slow to update. This may be a bug or a performance issue or you forgot to call `.done()`."
      );
    }, HANGING_STREAM_WARNING_TIME_MS);
  }
  // Build the chunk describing the current state. The initial chunk always
  // carries the full value; later chunks prefer the string diff when present.
  function snapshot(isInitialChunk) {
    const chunk =
      latestError !== void 0
        ? { error: latestError }
        : latestPatch && !isInitialChunk
          ? { diff: latestPatch }
          : { curr: latestValue };
    if (nextPromise) {
      chunk.next = nextPromise;
    }
    if (isInitialChunk) {
      chunk.type = STREAMABLE_VALUE_TYPE;
    }
    return chunk;
  }
  // Store `value`, recording `[0, suffix]` when a string value only grew.
  function recordValue(value) {
    const grewFromPrevious =
      typeof value === "string" &&
      typeof latestValue === "string" &&
      value.startsWith(latestValue);
    latestPatch = grewFromPrevious ? [0, value.slice(latestValue.length)] : void 0;
    latestValue = value;
  }
  // Swap in a fresh pending promise and return the resolver of the old one.
  function advance() {
    const resolvePrevious = pending.resolve;
    pending = createResolvablePromise();
    nextPromise = pending.promise;
    return resolvePrevious;
  }

  armHangTimer();

  const streamable = {
    set [STREAMABLE_VALUE_INTERNAL_LOCK](state) {
      isLocked = state;
    },
    get value() {
      return snapshot(true);
    },
    update(value) {
      ensureWritable(".update()");
      const resolvePrevious = advance();
      recordValue(value);
      resolvePrevious(snapshot());
      armHangTimer();
      return streamable;
    },
    append(value) {
      ensureWritable(".append()");
      if (typeof latestValue !== "string" && typeof latestValue !== "undefined") {
        throw new Error(
          `.append(): The current value is not a string. Received: ${typeof latestValue}`
        );
      }
      if (typeof value !== "string") {
        throw new Error(
          `.append(): The value is not a string. Received: ${typeof value}`
        );
      }
      const resolvePrevious = advance();
      if (typeof latestValue === "string") {
        latestPatch = [0, value];
        latestValue = latestValue + value;
      } else {
        latestPatch = void 0;
        latestValue = value;
      }
      resolvePrevious(snapshot());
      armHangTimer();
      return streamable;
    },
    error(error) {
      ensureWritable(".error()");
      clearHangTimer();
      isClosed = true;
      latestError = error;
      nextPromise = void 0;
      pending.resolve({ error });
      return streamable;
    },
    done(...args) {
      ensureWritable(".done()");
      clearHangTimer();
      isClosed = true;
      nextPromise = void 0;
      if (args.length === 0) {
        pending.resolve({});
        return streamable;
      }
      recordValue(args[0]);
      pending.resolve(snapshot());
      return streamable;
    }
  };
  return streamable;
}
export {
createAI,
createStreamableUI,
createStreamableValue,
getAIState,
getMutableAIState,
streamUI
};
//# sourceMappingURL=rsc-server.mjs.map