UNPKG

@ai-sdk/rsc

Version:

[React Server Components](https://react.dev/reference/rsc/server-components) for the [AI SDK](https://ai-sdk.dev/docs):

739 lines (726 loc) 20.4 kB
// Module compiled from several source files; each `// src/...` marker names
// the file the following section came from. Imports from all sections are
// gathered here in their original order.
import * as jsondiffpatch from "jsondiffpatch";
import { AsyncLocalStorage } from "async_hooks";
import * as React from "react";
import { InternalAIProvider } from "./rsc-shared.mjs";
import { Fragment, jsx, jsxs } from "react/jsx-runtime";
import { safeParseJSON } from "@ai-sdk/provider-utils";
import { InvalidToolInputError, NoSuchToolError } from "ai";
import {
  standardizePrompt,
  prepareToolsAndToolChoice,
  prepareRetries,
  prepareCallSettings,
  convertToLanguageModelPrompt
} from "ai/internal";
import { Suspense } from "react";

// src/util/create-resolvable-promise.ts

/**
 * Creates a promise and exposes its `resolve` / `reject` functions so it can
 * be settled from outside the promise executor.
 *
 * @returns {{ promise: Promise<any>, resolve: Function, reject: Function }}
 */
function createResolvablePromise() {
  let resolve;
  let reject;
  const promise = new Promise((res, rej) => {
    resolve = res;
    reject = rej;
  });
  return { promise, resolve, reject };
}

// src/util/is-function.ts
const isFunction = (value) => typeof value === "function";

// src/ai-state.tsx

// Store of the AI Action currently executing, bound to the async context by
// `withAIState` so helpers like `getAIState()` can find it implicitly.
const asyncAIStateStorage = new AsyncLocalStorage();

/**
 * Returns true when fields of `value` can be read/written by key.
 * `typeof null` is "object", so null must be excluded explicitly.
 */
function isIndexableState(value) {
  return typeof value === "object" && value !== null;
}

/**
 * Returns the AI state store of the running AI Action.
 *
 * @param {string} message - Error message used when no store is active.
 * @throws {Error} When called outside of a `withAIState` context.
 */
function getAIStateStoreOrThrow(message) {
  const store = asyncAIStateStorage.getStore();
  if (!store) {
    throw new Error(message);
  }
  return store;
}

/**
 * Reads the whole state, or a single field when `args` contains a key.
 *
 * @param {any} state - Current AI state.
 * @param {any[]} args - Empty, or `[key]`.
 * @throws {Error} When a key is given but the state is not an object.
 */
function readAIState(state, args) {
  if (args.length === 0) {
    return state;
  }
  const key = args[0];
  if (!isIndexableState(state)) {
    throw new Error(
      `You can't get the "${String(key)}" field from the AI state because it's not an object.`
    );
  }
  return state[key];
}

/**
 * Runs `fn` with a fresh AI state store bound to the async context.
 *
 * @param {{ state: any, options: object }} params - `state` is the incoming
 *   AI state; `options` may carry an `onSetAIState` callback.
 * @param {Function} fn - Callback executed inside the store's context.
 * @returns {any} Whatever `fn` returns.
 */
function withAIState({ state, options }, fn) {
  // Deep clone via JSON so updates never touch `originalState`, which the
  // final diff is computed against. JSON.stringify returns undefined for an
  // undefined (or otherwise non-serializable) root and JSON.parse(undefined)
  // throws, so such states are carried over as undefined instead.
  const serialized = JSON.stringify(state);
  return asyncAIStateStorage.run(
    {
      currentState: serialized === void 0 ? void 0 : JSON.parse(serialized),
      originalState: state,
      sealed: false,
      options
    },
    fn
  );
}

/** Returns the delta promise created by `getMutableAIState`, if any. */
function getAIStateDeltaPromise() {
  const store = getAIStateStoreOrThrow("Internal error occurred.");
  return store.mutationDeltaPromise;
}

/** Marks the store sealed so later `getMutableAIState` calls throw. */
function sealMutableAIState() {
  const store = getAIStateStoreOrThrow("Internal error occurred.");
  store.sealed = true;
}

/**
 * Returns the current AI state of the running AI Action, or one field of it
 * when a key is passed.
 *
 * @throws {Error} Outside an AI Action, or when a key is passed but the
 *   state is not an object.
 */
function getAIState(...args) {
  const store = getAIStateStoreOrThrow(
    "`getAIState` must be called within an AI Action."
  );
  return readAIState(store.currentState, args);
}

/**
 * Returns an accessor for the AI state (or one field of it when a key is
 * passed) with `get`, `update(newState | updater)` and `done(newState?)`.
 * `done` resolves the action's delta promise with a jsondiffpatch diff from
 * the original state to the current state.
 *
 * @throws {Error} Outside an AI Action, or after the action has returned.
 */
function getMutableAIState(...args) {
  const store = getAIStateStoreOrThrow(
    "`getMutableAIState` must be called within an AI Action."
  );
  if (store.sealed) {
    throw new Error(
      "`getMutableAIState` must be called before returning from an AI Action. Please move it to the top level of the Action's function body."
    );
  }
  if (!store.mutationDeltaPromise) {
    const { promise, resolve } = createResolvablePromise();
    store.mutationDeltaPromise = promise;
    store.mutationDeltaResolve = resolve;
  }
  const hasKey = args.length > 0;
  const key = args[0];

  // Applies a value or an updater function to the state (or the keyed
  // field) and notifies `onSetAIState` when one was configured.
  function doUpdate(newState, done) {
    if (hasKey && !isIndexableState(store.currentState)) {
      throw new Error(
        `You can't modify the "${String(key)}" field of the AI state because it's not an object.`
      );
    }
    if (isFunction(newState)) {
      if (hasKey) {
        store.currentState[key] = newState(store.currentState[key]);
      } else {
        store.currentState = newState(store.currentState);
      }
    } else if (hasKey) {
      store.currentState[key] = newState;
    } else {
      store.currentState = newState;
    }
    if (store.options.onSetAIState != null) {
      // Invoked as a method so `this` stays `store.options`.
      store.options.onSetAIState({
        key: hasKey ? key : void 0,
        state: store.currentState,
        done
      });
    }
  }

  return {
    get: () => readAIState(store.currentState, args),
    update: function update(newAIState) {
      doUpdate(newAIState, false);
    },
    done: function done(...doneArgs) {
      if (doneArgs.length > 0) {
        doUpdate(doneArgs[0], true);
      }
      const delta = jsondiffpatch.diff(store.originalState, store.currentState);
      store.mutationDeltaResolve(delta);
    }
  };
}

// src/provider.tsx

/**
 * Server action that runs `action` inside an AI state context and returns
 * `[deltaPromise, result]`.
 */
async function innerAction({ action, options }, state, ...args) {
  "use server";
  return await withAIState({ state, options }, async () => {
    const result = await action(...args);
    sealMutableAIState();
    return [getAIStateDeltaPromise(), result];
  });
}

/** Binds `action` and `options` as the first argument of `innerAction`. */
function wrapAction(action, options) {
  return innerAction.bind(null, { action, options });
}

/**
 * Creates the `AI` provider component. Each action is wrapped so it runs
 * with the AI state; `onGetUIState`, when given, is run once per render to
 * derive the initial UI state.
 *
 * @throws {Error} (from the returned component) when rendered outside
 *   Server Components.
 */
function createAI({
  actions,
  initialAIState,
  initialUIState,
  onSetAIState,
  onGetUIState
}) {
  const wrappedActions = {};
  for (const name in actions) {
    wrappedActions[name] = wrapAction(actions[name], { onSetAIState });
  }
  const wrappedSyncUIState = onGetUIState ? wrapAction(onGetUIState, {}) : void 0;

  const AI = async (props) => {
    // NOTE(review): relies on `useState` being absent from the React build
    // used for Server Components.
    if ("useState" in React) {
      throw new Error(
        "This component can only be used inside Server Components."
      );
    }
    let uiState = props.initialUIState != null ? props.initialUIState : initialUIState;
    const aiState = props.initialAIState != null ? props.initialAIState : initialAIState;
    let aiStateDelta = void 0;
    if (wrappedSyncUIState) {
      const [newAIStateDelta, newUIState] = await wrappedSyncUIState(aiState);
      if (newUIState !== void 0) {
        aiStateDelta = newAIStateDelta;
        uiState = newUIState;
      }
    }
    return /* @__PURE__ */ jsx(InternalAIProvider, {
      wrappedActions,
      wrappedSyncUIState,
      initialUIState: uiState,
      initialAIState: aiState,
      initialAIStatePatch: aiStateDelta,
      children: props.children
    });
  };
  return AI;
}

// src/util/is-async-generator.ts
// Requires a callable `next`: plain async iterables (e.g. ReadableStream)
// are not iterators, and `render` drives the value via `.next()`.
function isAsyncGenerator(value) {
  return (
    value != null &&
    typeof value === "object" &&
    Symbol.asyncIterator in value &&
    typeof value.next === "function"
  );
}

// src/util/is-generator.ts
// Requires a callable `next`: arrays of nodes are iterable but are not
// iterators, so they must be rendered as plain values.
function isGenerator(value) {
  return (
    value != null &&
    typeof value === "object" &&
    Symbol.iterator in value &&
    typeof value.next === "function"
  );
}

// src/util/constants.ts
// Delay without an update after which a development-mode warning is logged.
const HANGING_STREAM_WARNING_TIME_MS = 15 * 1e3;

// src/streamable-ui/create-suspended-chunk.tsx

// Async component: awaits chunk promise `n` and renders its value. Until a
// chunk flagged `done` arrives it nests another Suspense boundary (fallback:
// the chunk value) around itself; `append` chunks keep the previous
// content `c` in front of the new one.
const R = [
  async ({ c: current, n: next }) => {
    const chunk = await next;
    if (chunk.done) {
      return chunk.value;
    }
    if (chunk.append) {
      return /* @__PURE__ */ jsxs(Fragment, {
        children: [
          current,
          /* @__PURE__ */ jsx(Suspense, {
            fallback: chunk.value,
            children: /* @__PURE__ */ jsx(R, { c: chunk.value, n: chunk.next })
          })
        ]
      });
    }
    return /* @__PURE__ */ jsx(Suspense, {
      fallback: chunk.value,
      children: /* @__PURE__ */ jsx(R, { c: chunk.value, n: chunk.next })
    });
  }
][0];

/**
 * Creates the root Suspense row plus the resolve/reject functions of the
 * promise for its first chunk.
 */
function createSuspendedChunk(initialValue) {
  const { promise, resolve, reject } = createResolvablePromise();
  return {
    row: /* @__PURE__ */ jsx(Suspense, {
      fallback: initialValue,
      children: /* @__PURE__ */ jsx(R, { c: initialValue, n: promise })
    }),
    resolve,
    reject
  };
}

// src/streamable-ui/create-streamable-ui.tsx

/**
 * Creates a UI stream whose `value` is a renderable node that follows
 * `update` / `append` calls until `done` or `error` closes it.
 *
 * @param {any} initialValue - Node shown before the first update.
 * @throws {Error} (from the methods) when the stream is already closed.
 */
function createStreamableUI(initialValue) {
  let currentValue = initialValue;
  let closed = false;
  const chunk = createSuspendedChunk(initialValue);
  const { row } = chunk;
  let { resolve, reject } = chunk;

  function assertStream(method) {
    if (closed) {
      throw new Error(method + ": UI stream is already closed.");
    }
  }

  let warningTimeout;
  // Development only: (re)arms a timer that warns when no update follows.
  function warnUnclosedStream() {
    if (process.env.NODE_ENV === "development") {
      if (warningTimeout) {
        clearTimeout(warningTimeout);
      }
      warningTimeout = setTimeout(() => {
        console.warn(
          "The streamable UI has been slow to update. This may be a bug or a performance issue or you forgot to call `.done()`."
        );
      }, HANGING_STREAM_WARNING_TIME_MS);
    }
  }
  warnUnclosedStream();

  const streamable = {
    value: row,
    update(value) {
      assertStream(".update()");
      // Same value: nothing to render, only the warning timer is refreshed.
      if (value === currentValue) {
        warnUnclosedStream();
        return streamable;
      }
      const resolvable = createResolvablePromise();
      currentValue = value;
      resolve({ value: currentValue, done: false, next: resolvable.promise });
      resolve = resolvable.resolve;
      reject = resolvable.reject;
      warnUnclosedStream();
      return streamable;
    },
    append(value) {
      assertStream(".append()");
      const resolvable = createResolvablePromise();
      currentValue = value;
      resolve({ value, done: false, append: true, next: resolvable.promise });
      resolve = resolvable.resolve;
      reject = resolvable.reject;
      warnUnclosedStream();
      return streamable;
    },
    error(error) {
      assertStream(".error()");
      if (warningTimeout) {
        clearTimeout(warningTimeout);
      }
      closed = true;
      reject(error);
      return streamable;
    },
    done(...args) {
      assertStream(".done()");
      if (warningTimeout) {
        clearTimeout(warningTimeout);
      }
      closed = true;
      if (args.length) {
        resolve({ value: args[0], done: true });
        return streamable;
      }
      resolve({ value: currentValue, done: true });
      return streamable;
    }
  };
  return streamable;
}

// src/stream-ui/stream-ui.tsx
const defaultTextRenderer = ({ content }) => content;

/**
 * Streams a model response as UI. Text deltas are rendered through `text`
 * (default: the accumulated text); a tool call is rendered through the
 * matching tool's `generate`. Renderers may return a value, a promise, or a
 * (async) generator whose yields become intermediate updates.
 *
 * @returns {Promise<object>} The `doStream` result with `stream` replaced by
 *   one branch of the tee'd stream, plus `value` (the UI node).
 * @throws {Error} Synchronously-detected misuse: string `model`, `functions`
 *   or `provider` settings, or a tool with a `render` property. Errors
 *   raised while consuming the stream (unknown tool, invalid tool input,
 *   stream errors, renderer errors, `onFinish` errors) are delivered
 *   through `ui.error` instead.
 */
async function streamUI({
  model,
  tools,
  toolChoice,
  system,
  prompt,
  messages,
  maxRetries,
  abortSignal,
  headers,
  initial,
  text,
  providerOptions,
  onFinish,
  ...settings
}) {
  if (typeof model === "string") {
    throw new Error(
      "`model` cannot be a string in `streamUI`. Use the actual model instance instead."
    );
  }
  if ("functions" in settings) {
    throw new Error(
      "`functions` is not supported in `streamUI`, use `tools` instead."
    );
  }
  if ("provider" in settings) {
    throw new Error(
      "`provider` is no longer needed in `streamUI`. Use `model` instead."
    );
  }
  if (tools) {
    for (const [name, tool] of Object.entries(tools)) {
      if ("render" in tool) {
        throw new Error(
          "Tool definition in `streamUI` should not have `render` property. Use `generate` instead. Found in tool: " + name
        );
      }
    }
  }

  const ui = createStreamableUI(initial);
  const textRender = text || defaultTextRenderer;
  // Chain that settles once every render call started so far has finished.
  let finished;
  let finishEvent = null;
  // `{ error }` of the first renderer that threw; rethrown after `finished`.
  let renderFailure;

  async function render({ args, renderer, streamableUI, isLastCall = false }) {
    if (!renderer) return;
    const renderFinished = createResolvablePromise();
    finished = finished
      ? finished.then(() => renderFinished.promise)
      : renderFinished.promise;
    // Render calls are not awaited by the caller, so errors are captured
    // here instead of escaping as unhandled rejections, and the `finally`
    // guarantees `finished` settles even when a renderer throws.
    try {
      const rendererResult = renderer(...args);
      if (isAsyncGenerator(rendererResult) || isGenerator(rendererResult)) {
        while (true) {
          const { done, value } = await rendererResult.next();
          const node = await value;
          if (isLastCall && done) {
            streamableUI.done(node);
          } else {
            streamableUI.update(node);
          }
          if (done) break;
        }
      } else {
        const node = await rendererResult;
        if (isLastCall) {
          streamableUI.done(node);
        } else {
          streamableUI.update(node);
        }
      }
    } catch (error) {
      if (renderFailure === void 0) {
        renderFailure = { error };
      }
    } finally {
      renderFinished.resolve(void 0);
    }
  }

  const { retry } = prepareRetries({ maxRetries, abortSignal });
  const validatedPrompt = await standardizePrompt({ system, prompt, messages });
  const result = await retry(
    async () => model.doStream({
      ...prepareCallSettings(settings),
      ...prepareToolsAndToolChoice({ tools, toolChoice, activeTools: void 0 }),
      prompt: await convertToLanguageModelPrompt({
        prompt: validatedPrompt,
        supportedUrls: await model.supportedUrls,
        download: void 0
      }),
      providerOptions,
      abortSignal,
      headers,
      includeRawChunks: false
    })
  );
  // One branch is returned to the caller, the other drives the UI.
  const [stream, forkedStream] = result.stream.tee();

  // Consumes the forked stream in the background; the UI is returned first.
  void (async () => {
    try {
      let content = "";
      let hasToolCall = false;
      let warnings;
      const reader = forkedStream.getReader();
      while (true) {
        const { done, value } = await reader.read();
        if (done) break;
        switch (value.type) {
          case "stream-start": {
            warnings = value.warnings;
            break;
          }
          case "text-delta": {
            content += value.delta;
            render({
              renderer: textRender,
              args: [{ content, done: false, delta: value.delta }],
              streamableUI: ui
            });
            break;
          }
          case "tool-input-start":
          case "tool-input-delta": {
            hasToolCall = true;
            break;
          }
          case "tool-call": {
            const toolName = value.toolName;
            if (!tools) {
              throw new NoSuchToolError({ toolName });
            }
            const tool = tools[toolName];
            if (!tool) {
              throw new NoSuchToolError({
                toolName,
                availableTools: Object.keys(tools)
              });
            }
            hasToolCall = true;
            const parseResult = await safeParseJSON({
              text: value.input,
              schema: tool.inputSchema
            });
            if (parseResult.success === false) {
              throw new InvalidToolInputError({
                toolName,
                toolInput: value.input,
                cause: parseResult.error
              });
            }
            render({
              renderer: tool.generate,
              args: [
                parseResult.value,
                { toolName, toolCallId: value.toolCallId }
              ],
              streamableUI: ui,
              isLastCall: true
            });
            break;
          }
          case "error": {
            throw value.error;
          }
          case "finish": {
            finishEvent = {
              finishReason: value.finishReason,
              usage: value.usage,
              warnings,
              response: result.response
            };
            break;
          }
        }
      }
      if (!hasToolCall) {
        render({
          renderer: textRender,
          args: [{ content, done: true }],
          streamableUI: ui,
          isLastCall: true
        });
      }
      await finished;
      if (renderFailure !== void 0) {
        throw renderFailure.error;
      }
      if (finishEvent && onFinish) {
        await onFinish({ ...finishEvent, value: ui.value });
      }
    } catch (error) {
      ui.error(error);
    }
  })();

  return { ...result, stream, value: ui.value };
}

// src/streamable-value/streamable-value.ts
const STREAMABLE_VALUE_TYPE = Symbol.for("ui.streamable.value");

// src/streamable-value/create-streamable-value.ts
// Write-only key used to toggle the internal lock of a streamable value.
const STREAMABLE_VALUE_INTERNAL_LOCK = Symbol("streamable.value.lock");

/**
 * Creates a streamable value. When `initialValue` looks like a
 * ReadableStream, the value is fed from that stream (string chunks are
 * appended, other chunks replace the value) and is locked against external
 * updates.
 */
function createStreamableValue(initialValue) {
  const isReadableStream =
    initialValue instanceof ReadableStream ||
    (typeof initialValue === "object" &&
      initialValue !== null &&
      "getReader" in initialValue &&
      typeof initialValue.getReader === "function" &&
      "locked" in initialValue &&
      typeof initialValue.locked === "boolean");
  if (!isReadableStream) {
    return createStreamableValueImpl(initialValue);
  }

  const streamableValue = createStreamableValueImpl();
  streamableValue[STREAMABLE_VALUE_INTERNAL_LOCK] = true;
  (async () => {
    try {
      const reader = initialValue.getReader();
      while (true) {
        const { value, done } = await reader.read();
        if (done) {
          break;
        }
        // Unlock only for the duration of our own write.
        streamableValue[STREAMABLE_VALUE_INTERNAL_LOCK] = false;
        if (typeof value === "string") {
          streamableValue.append(value);
        } else {
          streamableValue.update(value);
        }
        streamableValue[STREAMABLE_VALUE_INTERNAL_LOCK] = true;
      }
      streamableValue[STREAMABLE_VALUE_INTERNAL_LOCK] = false;
      streamableValue.done();
    } catch (e) {
      streamableValue[STREAMABLE_VALUE_INTERNAL_LOCK] = false;
      streamableValue.error(e);
    }
  })();
  return streamableValue;
}

/**
 * Core streamable value: each update resolves the previous chunk promise
 * with `{ curr }` or, for string extensions, `{ diff: [0, suffix] }`,
 * together with `next` (the promise of the following chunk).
 */
function createStreamableValueImpl(initialValue) {
  let closed = false;
  let locked = false;
  let resolvable = createResolvablePromise();
  let currentValue = initialValue;
  let currentError;
  let currentPromise = resolvable.promise;
  let currentPatchValue;

  function assertStream(method) {
    if (closed) {
      throw new Error(method + ": Value stream is already closed.");
    }
    if (locked) {
      throw new Error(
        method + ": Value stream is locked and cannot be updated."
      );
    }
  }

  let warningTimeout;
  // Development only: (re)arms a timer that warns when no update follows.
  function warnUnclosedStream() {
    if (process.env.NODE_ENV === "development") {
      if (warningTimeout) {
        clearTimeout(warningTimeout);
      }
      warningTimeout = setTimeout(() => {
        console.warn(
          "The streamable value has been slow to update. This may be a bug or a performance issue or you forgot to call `.done()`."
        );
      }, HANGING_STREAM_WARNING_TIME_MS);
    }
  }
  warnUnclosedStream();

  // Builds the chunk payload; the initial chunk always carries the full
  // value and is tagged with STREAMABLE_VALUE_TYPE.
  function createWrapped(initialChunk) {
    let init;
    if (currentError !== void 0) {
      init = { error: currentError };
    } else if (currentPatchValue && !initialChunk) {
      init = { diff: currentPatchValue };
    } else {
      init = { curr: currentValue };
    }
    if (currentPromise) {
      init.next = currentPromise;
    }
    if (initialChunk) {
      init.type = STREAMABLE_VALUE_TYPE;
    }
    return init;
  }

  // Stores `value` and, when it extends the previous string, records the
  // appended suffix as a patch.
  function updateValueStates(value) {
    currentPatchValue = void 0;
    if (
      typeof value === "string" &&
      typeof currentValue === "string" &&
      value.startsWith(currentValue)
    ) {
      currentPatchValue = [0, value.slice(currentValue.length)];
    }
    currentValue = value;
  }

  const streamable = {
    set [STREAMABLE_VALUE_INTERNAL_LOCK](state) {
      locked = state;
    },
    get value() {
      return createWrapped(true);
    },
    update(value) {
      assertStream(".update()");
      const resolvePrevious = resolvable.resolve;
      resolvable = createResolvablePromise();
      updateValueStates(value);
      currentPromise = resolvable.promise;
      resolvePrevious(createWrapped());
      warnUnclosedStream();
      return streamable;
    },
    append(value) {
      assertStream(".append()");
      if (typeof currentValue !== "string" && typeof currentValue !== "undefined") {
        throw new Error(
          `.append(): The current value is not a string. Received: ${typeof currentValue}`
        );
      }
      if (typeof value !== "string") {
        throw new Error(
          `.append(): The value is not a string. Received: ${typeof value}`
        );
      }
      const resolvePrevious = resolvable.resolve;
      resolvable = createResolvablePromise();
      if (typeof currentValue === "string") {
        currentPatchValue = [0, value];
        currentValue = currentValue + value;
      } else {
        currentPatchValue = void 0;
        currentValue = value;
      }
      currentPromise = resolvable.promise;
      resolvePrevious(createWrapped());
      warnUnclosedStream();
      return streamable;
    },
    error(error) {
      assertStream(".error()");
      if (warningTimeout) {
        clearTimeout(warningTimeout);
      }
      closed = true;
      currentError = error;
      currentPromise = void 0;
      resolvable.resolve({ error });
      return streamable;
    },
    done(...args) {
      assertStream(".done()");
      if (warningTimeout) {
        clearTimeout(warningTimeout);
      }
      closed = true;
      currentPromise = void 0;
      if (args.length) {
        updateValueStates(args[0]);
        resolvable.resolve(createWrapped());
        return streamable;
      }
      resolvable.resolve({});
      return streamable;
    }
  };
  return streamable;
}

export {
  createAI,
  createStreamableUI,
  createStreamableValue,
  getAIState,
  getMutableAIState,
  streamUI
};
//# sourceMappingURL=rsc-server.mjs.map