UNPKG

@convex-dev/agent

Version:

An agent component for Convex.

294 lines 9.92 kB
import { smoothStream } from "ai";
import { v } from "convex/values";
import {
  vMessageDoc,
  vPaginationResult,
  vStreamDelta,
  vStreamMessage,
} from "../validators.js";

/**
 * Validator for a paginated `vMessageDoc` result extended with an optional
 * `streams` field. That field is either `{ kind: "list", messages }` or
 * `{ kind: "deltas", deltas }`, matching the two shapes `syncStreams` returns.
 */
export const vStreamMessagesReturnValue = v.object({
  ...vPaginationResult(vMessageDoc).fields,
  streams: v.optional(
    v.union(
      v.object({ kind: v.literal("list"), messages: v.array(vStreamMessage) }),
      v.object({ kind: v.literal("deltas"), deltas: v.array(vStreamDelta) }),
    ),
  ),
});

/**
 * A function that handles fetching stream deltas, used with the React hooks
 * `useThreadMessages` or `useStreamingThreadMessages`.
 * @param ctx A ctx object from a query, mutation, or action.
 * @param component The agent component, usually `components.agent`.
 * @param args.threadId The thread to sync streams for.
 * @param args.streamArgs The stream arguments: either `{ kind: "list", startOrder }`
 *   to list streams, or `{ kind: "deltas", cursors }` with per-stream cursors.
 * @param args.includeStatuses Statuses to include when listing streams
 *   (only used for `kind: "list"`).
 * @returns `undefined` when `streamArgs` is not provided; otherwise either
 *   `{ kind: "list", messages }` or `{ kind: "deltas", deltas }`, where the
 *   deltas for each stream start from that stream's cursor.
 */
export async function syncStreams(ctx, component, { threadId, streamArgs, includeStatuses }) {
  if (!streamArgs) {
    return undefined;
  }
  if (streamArgs.kind === "list") {
    return {
      kind: "list",
      messages: await listStreams(ctx, component, {
        threadId,
        startOrder: streamArgs.startOrder,
        includeStatuses,
      }),
    };
  }
  return {
    kind: "deltas",
    deltas: await ctx.runQuery(component.streams.listDeltas, {
      threadId,
      cursors: streamArgs.cursors,
    }),
  };
}

/**
 * Abort a stream, addressed either by id or by thread + message order.
 * @param ctx A ctx object that can run mutations.
 * @param component The agent component, usually `components.agent`.
 * @param args Either `{ streamId, reason }` (uses `streams.abort`) or
 *   `{ threadId, order, reason }` (uses `streams.abortByOrder`).
 * @returns Whatever the underlying abort mutation returns.
 */
export async function abortStream(ctx, component, args) {
  if ("streamId" in args) {
    return await ctx.runMutation(component.streams.abort, {
      reason: args.reason,
      streamId: args.streamId,
    });
  }
  return await ctx.runMutation(component.streams.abortByOrder, {
    reason: args.reason,
    threadId: args.threadId,
    order: args.order,
  });
}

/**
 * List the streaming messages for a thread.
 * @param ctx A ctx object from a query, mutation, or action.
 * @param component The agent component, usually `components.agent`.
 * @param args.threadId The thread to list streams for.
 * @param args.startOrder The order of the messages in the thread to start listing from.
 * @param args.includeStatuses The statuses to include in the list
 *   (forwarded as `statuses`).
 * @returns The streams for the thread.
 */
export async function listStreams(ctx, component, { threadId, startOrder, includeStatuses }) {
  return ctx.runQuery(component.streams.list, {
    threadId,
    startOrder,
    statuses: includeStatuses,
  });
}

/** Defaults applied when streaming options are enabled but not fully specified. */
export const DEFAULT_STREAMING_OPTIONS = {
  // This chunks by sentences / clauses. Punctuation followed by whitespace.
  chunking: /[\p{P}\s]/u,
  throttleMs: 250,
  returnImmediately: false,
};

/**
 * Append a `smoothStream` chunking transform to any existing transforms.
 * @param options The options passed to `agent.streamText` to decide whether to
 *   save deltas while streaming. `true` uses the default chunking; an object
 *   may supply its own `chunking` (falls back to the default when omitted).
 * @param existing The transforms passed to `agent.streamText` to merge with:
 *   a single transform, an array of them, or nothing.
 * @returns `existing` unchanged when `options` is falsy; otherwise a NEW array
 *   containing the existing transforms followed by the chunking transform.
 *   The caller's array is never mutated.
 */
export function mergeTransforms(options, existing) {
  if (!options) {
    return existing;
  }
  const chunking =
    typeof options === "boolean"
      ? DEFAULT_STREAMING_OPTIONS.chunking
      : (options.chunking ?? DEFAULT_STREAMING_OPTIONS.chunking);
  // Copy so we don't push into an array owned by the caller.
  const transforms = Array.isArray(existing) ? [...existing] : existing ? [existing] : [];
  transforms.push(smoothStream({ delayInMs: null, chunking }));
  return transforms;
}

/**
 * DeltaStreamer can be used to save a stream of "parts" by writing
 * batches of them in "deltas" to the database so clients can subscribe
 * (using the syncStreams utility and client hooks) and re-hydrate the stream.
 * You can optionally compress the parts, e.g. concatenating text deltas, to
 * optimize the data in transit.
 *
 * Writes are throttled by `config.throttleMs`. At most one write is in
 * flight at a time (tracked by `#ongoingWrite`).
 */
export class DeltaStreamer {
  component;
  ctx;
  metadata;
  streamId;
  config;
  abortController;
  #nextParts = [];
  #latestWrite = 0;
  #ongoingWrite;
  #cursor = 0;
  // Avoid race conditions by only creating the stream once.
  #creatingStreamIdPromise;

  /**
   * @param component The agent component, usually `components.agent`.
   * @param ctx A ctx object that can run mutations.
   * @param config `{ throttleMs?, onAsyncAbort, compress?, abortSignal? }`.
   * @param metadata Arguments passed to `streams.create` when the stream is created.
   */
  constructor(component, ctx, config, metadata) {
    this.component = component;
    this.ctx = ctx;
    this.metadata = metadata;
    this.config = {
      throttleMs: config.throttleMs ?? DEFAULT_STREAMING_OPTIONS.throttleMs,
      onAsyncAbort: config.onAsyncAbort,
      compress: config.compress,
    };
    this.abortController = new AbortController();
    if (config.abortSignal) {
      config.abortSignal.addEventListener("abort", async () => {
        if (this.abortController.signal.aborted) {
          return;
        }
        // Stop local writes even if the stream does not exist yet, so later
        // addParts calls don't create / write to it.
        this.abortController.abort();
        let streamId = this.streamId;
        if (!streamId && this.#creatingStreamIdPromise) {
          // Creation is in flight: wait for it so that stream is aborted too.
          try {
            streamId = await this.#creatingStreamIdPromise;
          } catch {
            // The creation error is already surfaced to the getStreamId caller.
            return;
          }
        }
        if (!streamId) {
          return;
        }
        await this.#ongoingWrite;
        await this.ctx.runMutation(this.component.streams.abort, {
          streamId,
          reason: "abortSignal",
        });
      });
    }
  }

  /**
   * Create the stream on first use and return its id. Concurrent callers
   * share the same creation mutation.
   * @returns The stream id.
   */
  async getStreamId() {
    if (this.streamId) {
      return this.streamId;
    }
    if (!this.#creatingStreamIdPromise) {
      this.#creatingStreamIdPromise = this.ctx.runMutation(
        this.component.streams.create,
        this.metadata,
      );
    }
    this.streamId = await this.#creatingStreamIdPromise;
    return this.streamId;
  }

  /**
   * Queue parts for writing, creating the stream if needed. A write is
   * started only when parts are pending, no write is in flight, and at
   * least `throttleMs` has passed since the last write. No-op once aborted.
   */
  async addParts(parts) {
    if (this.abortController.signal.aborted) {
      return;
    }
    await this.getStreamId();
    this.#nextParts.push(...parts);
    if (
      this.#nextParts.length > 0 &&
      !this.#ongoingWrite &&
      Date.now() - this.#latestWrite >= this.config.throttleMs
    ) {
      this.#ongoingWrite = this.#sendDelta();
    }
  }

  /** Write every chunk of `stream` as a part, then finish the stream. */
  async consumeStream(stream) {
    for await (const chunk of stream) {
      await this.addParts([chunk]);
    }
    await this.finish();
  }

  /**
   * Write pending parts as a delta. If more parts accumulated and the
   * throttle window has elapsed, keep writing. This loops inside ONE promise,
   * so awaiting `#ongoingWrite` waits for every chained write.
   */
  async #sendDelta() {
    do {
      if (this.abortController.signal.aborted) {
        return;
      }
      const delta = this.#createDelta();
      if (!delta) {
        return;
      }
      this.#latestWrite = Date.now();
      try {
        const success = await this.ctx.runMutation(this.component.streams.addDelta, delta);
        if (!success) {
          await this.config.onAsyncAbort("async abort");
          this.abortController.abort();
          return;
        }
      } catch (e) {
        await this.config.onAsyncAbort(e instanceof Error ? e.message : "unknown error");
        this.abortController.abort();
        throw e;
      }
    } while (
      this.#nextParts.length > 0 &&
      Date.now() - this.#latestWrite >= this.config.throttleMs
    );
    this.#ongoingWrite = undefined;
  }

  /**
   * Build the next delta from the pending parts (compressed if configured)
   * and advance the cursor.
   * @returns `{ streamId, start, end, parts }`, or `undefined` if nothing is pending.
   * @throws If called before the stream has been created.
   */
  #createDelta() {
    if (this.#nextParts.length === 0) {
      return undefined;
    }
    // Check before mutating state so pending parts aren't dropped on error.
    if (!this.streamId) {
      throw new Error("Creating a delta before the stream is created");
    }
    const start = this.#cursor;
    const end = start + this.#nextParts.length;
    this.#cursor = end;
    const parts = this.config.compress
      ? this.config.compress(this.#nextParts)
      : this.#nextParts;
    this.#nextParts = [];
    return { streamId: this.streamId, start, end, parts };
  }

  /**
   * Wait for in-flight writes, flush any remaining parts, then mark the
   * stream finished. No-op if the stream was never created.
   */
  async finish() {
    if (!this.streamId) {
      return;
    }
    await this.#ongoingWrite;
    await this.#sendDelta();
    await this.ctx.runMutation(this.component.streams.finish, {
      streamId: this.streamId,
    });
  }

  /**
   * Abort locally and, if the stream exists, abort it in the database with
   * `reason` once any in-flight write settles. No-op if already aborted.
   */
  async fail(reason) {
    if (this.abortController.signal.aborted) {
      return;
    }
    this.abortController.abort();
    if (!this.streamId) {
      return;
    }
    await this.#ongoingWrite;
    await this.ctx.runMutation(this.component.streams.abort, {
      streamId: this.streamId,
      reason,
    });
  }
}

/**
 * Compressing parts when streaming to save bandwidth in deltas.
 * Adjacent `text-delta` / `reasoning-delta` chunks of the same type and id are
 * merged by concatenating `delta`. Merging is done on copies, so the
 * caller's chunk objects are not mutated. Other chunks pass through unchanged.
 */
export function compressUIMessageChunks(parts) {
  const compressed = [];
  for (const part of parts) {
    const last = compressed.at(-1);
    const isDelta = part.type === "text-delta" || part.type === "reasoning-delta";
    if (isDelta && last?.type === part.type && last.id === part.id) {
      last.delta += part.delta;
    } else {
      // Copy deltas so later merges modify our copy, not the caller's chunk.
      compressed.push(isDelta ? { ...part } : part);
    }
  }
  return compressed;
}

/**
 * Compress text stream parts before writing them as deltas.
 * - Adjacent `text-delta` / `reasoning-delta` parts of the same type and id are
 *   merged by concatenating `text` (on copies; inputs are not mutated).
 * - `file` parts are emitted once, as `{ type: "file", file }` with
 *   `file.uint8Array` set to `undefined`.
 * - All other parts pass through unchanged.
 */
export function compressTextStreamParts(parts) {
  const compressed = [];
  for (const part of parts) {
    const last = compressed.at(-1);
    if (part.type === "text-delta" || part.type === "reasoning-delta") {
      if (last?.type === part.type && last.id === part.id) {
        last.text += part.text;
      } else {
        compressed.push({ ...part });
      }
    } else if (part.type === "file") {
      compressed.push({
        type: "file",
        file: {
          ...part.file,
          uint8Array: undefined,
        },
      });
    } else {
      compressed.push(part);
    }
  }
  return compressed;
}
//# sourceMappingURL=streaming.js.map