UNPKG

@assistant-ui/react

Version:

TypeScript/React library for AI chat

156 lines 4.79 kB
// src/runtimes/edge/createEdgeRuntimeAPI.ts
import { EdgeRuntimeRequestOptionsSchema } from "./EdgeRuntimeRequestOptions.mjs";
import { toLanguageModelMessages } from "./converters/toLanguageModelMessages.mjs";
import { toLanguageModelTools } from "./converters/toLanguageModelTools.mjs";
import { toolResultStream } from "./streams/toolResultStream.mjs";
import { LanguageModelV1CallSettingsSchema } from "../../model-context/ModelContextTypes.mjs";
import { AssistantMessageAccumulator, DataStreamEncoder } from "assistant-stream";
import { LanguageModelV1StreamDecoder } from "assistant-stream/ai-sdk";

/**
 * Runs one model turn for an edge-runtime chat request and returns the
 * decoded chunk stream that is sent back to the client.
 *
 * The request body (`requestData`) is untrusted and is validated with
 * `EdgeRuntimeRequestOptionsSchema`; the leftover server call settings are
 * validated with `LanguageModelV1CallSettingsSchema`.
 *
 * @param {object} params
 * @param {AbortSignal} [params.abortSignal] Forwarded to the model call and to
 *   server-side tool execution.
 * @param {unknown} params.requestData Raw (untrusted) request body.
 * @param {object} params.options Server configuration: `model` (a model, or a
 *   function receiving `{ apiKey, baseUrl, modelName }` from the request and
 *   returning/resolving to one), `system`, `tools`, `toolChoice`, `onFinish`,
 *   and any further call settings.
 * @returns {Promise<ReadableStream>} The decoded stream for the client.
 * @throws {Error} If a client tool has the same name as a server tool.
 */
var getEdgeRuntimeStream = async ({
  abortSignal,
  requestData: unsafeRequest,
  options: {
    model: modelOrCreator,
    system: serverSystem,
    tools: serverTools = {},
    toolChoice,
    onFinish,
    ...unsafeSettings
  },
}) => {
  const settings = LanguageModelV1CallSettingsSchema.parse(unsafeSettings);
  const lmServerTools = toLanguageModelTools(serverTools);
  const hasServerTools = Object.values(serverTools).some((v) => !!v.execute);

  const {
    system: clientSystem,
    tools: clientTools = [],
    messages,
    apiKey,
    baseUrl,
    modelName,
    ...callSettings
  } = EdgeRuntimeRequestOptionsSchema.parse(unsafeRequest);

  // Server system prompt first, then the client's, separated by a blank line.
  const systemMessages = [];
  if (serverSystem) systemMessages.push(serverSystem);
  if (clientSystem) systemMessages.push(clientSystem);
  const system = systemMessages.join("\n\n");

  for (const clientTool of clientTools) {
    // Own-property check: a plain `serverTools[name]` lookup would also hit
    // inherited members such as `toString` and reject legitimate tool names.
    if (
      serverTools != null &&
      Object.prototype.hasOwnProperty.call(serverTools, clientTool.name) &&
      serverTools[clientTool.name]
    ) {
      throw new Error(
        `Tool ${clientTool.name} was defined in both the client and server tools. This is not allowed.`
      );
    }
  }

  const model =
    typeof modelOrCreator === "function"
      ? await modelOrCreator({ apiKey, baseUrl, modelName })
      : modelOrCreator;

  const streamResult = await streamMessage({
    ...settings,
    ...callSettings,
    model,
    abortSignal,
    ...(system ? { system } : undefined),
    messages,
    tools: lmServerTools.concat(clientTools),
    ...(toolChoice ? { toolChoice } : undefined),
  });

  let stream = streamResult.stream.pipeThrough(new LanguageModelV1StreamDecoder());

  const canExecuteTools = hasServerTools && toolChoice?.type !== "none";
  if (canExecuteTools) {
    stream = stream.pipeThrough(toolResultStream(serverTools, abortSignal));
  }

  // Tee only when something actually reads the second branch. An unread tee
  // branch buffers every chunk and keeps client-side cancellation from
  // reaching the source (a tee cancels its source only when both branches do).
  if (onFinish) {
    const [clientStream, serverStream] = stream.tee();
    stream = clientStream;
    reportOnFinish(serverStream, messages, onFinish);
  }

  return stream;
};

/**
 * Folds `serverStream` into the final assistant message and, if the run
 * ended with a status that is not "running", calls `onFinish` with the
 * request messages followed by that assistant message.
 *
 * Fire-and-forget: failures from the stream or from `onFinish` (thrown or a
 * rejected promise) are logged instead of surfacing as unhandled rejections.
 */
var reportOnFinish = (serverStream, messages, onFinish) => {
  let lastChunk;
  serverStream
    .pipeThrough(new AssistantMessageAccumulator())
    .pipeTo(
      new WritableStream({
        write(chunk) {
          lastChunk = chunk;
        },
        close() {
          if (!lastChunk?.status || lastChunk.status.type === "running") return;
          const resultingMessages = [
            ...messages,
            {
              id: "DEFAULT",
              createdAt: new Date(),
              role: "assistant",
              content: lastChunk.content,
              status: lastChunk.status,
              metadata: lastChunk.metadata,
            },
          ];
          // Returning the result lets a rejected async onFinish error the
          // pipe, so it reaches the .catch below.
          return onFinish({
            messages: resultingMessages,
            metadata: { steps: lastChunk.metadata.steps },
          });
        },
      })
    )
    .catch((e) => {
      console.error("Server stream processing error:", e);
    });
};

/**
 * Wraps `getEdgeRuntimeStream` in a `Response` whose body is the stream
 * encoded with `DataStreamEncoder`.
 *
 * @param {object} options Same parameter object as `getEdgeRuntimeStream`.
 * @returns {Promise<Response>}
 */
var getEdgeRuntimeResponse = async (options) => {
  const stream = await getEdgeRuntimeStream(options);
  return new Response(stream.pipeThrough(new DataStreamEncoder()), {
    headers: {
      "Content-Type": "text/plain; charset=utf-8",
      "x-vercel-ai-data-stream": "v1",
    },
  });
};

/**
 * Builds a route-handler object with a `POST` method. The handler parses the
 * request body as JSON and uses the request's abort signal for the model
 * call.
 *
 * @param {object} options Server configuration passed through as
 *   `options` to `getEdgeRuntimeResponse`.
 * @returns {{ POST: (request: Request) => Promise<Response> }}
 */
var createEdgeRuntimeAPI = (options) => ({
  POST: async (request) =>
    getEdgeRuntimeResponse({
      abortSignal: request.signal,
      requestData: await request.json(),
      options,
    }),
});

/**
 * Calls `model.doStream` in "regular" mode. `tools` and `toolChoice` are
 * included only when truthy; every remaining option is spread last into the
 * call.
 */
async function streamMessage({ model, system, messages, tools, toolChoice, ...options }) {
  return model.doStream({
    inputFormat: "messages",
    mode: {
      type: "regular",
      ...(tools ? { tools } : undefined),
      ...(toolChoice ? { toolChoice } : undefined),
    },
    prompt: convertToLanguageModelPrompt(system, messages),
    ...options,
  });
}

/**
 * Builds the prompt array for the model: an optional leading system message
 * (added when `system` is not null/undefined), followed by `messages`
 * converted by `toLanguageModelMessages`.
 */
function convertToLanguageModelPrompt(system, messages) {
  const languageModelMessages = [];
  if (system != null) {
    languageModelMessages.push({ role: "system", content: system });
  }
  languageModelMessages.push(...toLanguageModelMessages(messages));
  return languageModelMessages;
}

export {
  convertToLanguageModelPrompt,
  createEdgeRuntimeAPI,
  getEdgeRuntimeResponse,
  getEdgeRuntimeStream,
};
//# sourceMappingURL=createEdgeRuntimeAPI.mjs.map