UNPKG

@convex-dev/agent

Version:

An agent component for Convex.

428 lines (409 loc) 12.1 kB
import {
  smoothStream,
  type AsyncIterableStream,
  type ChunkDetector,
  type StreamTextTransform,
  type TextStreamPart,
  type ToolSet,
  type UIMessageChunk,
} from "ai";
import { v } from "convex/values";
import {
  vMessageDoc,
  vPaginationResult,
  vStreamDelta,
  vStreamMessage,
  type ProviderOptions,
  type StreamArgs,
  type StreamDelta,
  type StreamMessage,
} from "../validators.js";
import type {
  ActionCtx,
  AgentComponent,
  MutationCtx,
  QueryCtx,
  SyncStreamsReturnValue,
} from "./types.js";

/**
 * Validator for a paginated list of message docs with an optional `streams`
 * field, which is either a list of stream messages or a list of stream deltas.
 */
export const vStreamMessagesReturnValue = v.object({
  ...vPaginationResult(vMessageDoc).fields,
  streams: v.optional(
    v.union(
      v.object({ kind: v.literal("list"), messages: v.array(vStreamMessage) }),
      v.object({ kind: v.literal("deltas"), deltas: v.array(vStreamDelta) }),
    ),
  ),
});

/**
 * A function that handles fetching stream deltas, used with the React hooks
 * `useThreadMessages` or `useStreamingThreadMessages`.
 * @param ctx A ctx object from a query, mutation, or action.
 * @param component The agent component, usually `components.agent`.
 * @param args.threadId The thread to sync streams for.
 * @param args.streamArgs The stream arguments with per-stream cursors.
 * @param args.includeStatuses The stream statuses to include when listing
 *   streams. Only used when `streamArgs.kind` is "list".
 * @returns `undefined` when no `streamArgs` are given. Otherwise the list of
 *   streams (kind "list") or the deltas for each stream from their existing
 *   cursor (kind "deltas").
 */
export async function syncStreams(
  ctx: QueryCtx | MutationCtx | ActionCtx,
  component: AgentComponent,
  {
    threadId,
    streamArgs,
    includeStatuses,
  }: {
    threadId: string;
    streamArgs?: StreamArgs | undefined;
    // By default, only streaming messages are included.
    includeStatuses?: ("streaming" | "finished" | "aborted")[];
  },
): Promise<SyncStreamsReturnValue | undefined> {
  if (!streamArgs) return undefined;
  if (streamArgs.kind === "list") {
    return {
      kind: "list",
      messages: await listStreams(ctx, component, {
        threadId,
        startOrder: streamArgs.startOrder,
        includeStatuses,
      }),
    };
  } else {
    return {
      kind: "deltas",
      deltas: await ctx.runQuery(component.streams.listDeltas, {
        threadId,
        cursors: streamArgs.cursors,
      }),
    };
  }
}

/**
 * Abort a stream, identified either by its stream ID or by its thread and
 * message order.
 * @param ctx A ctx object from a mutation or action.
 * @param component The agent component, usually `components.agent`.
 * @param args.reason The reason recorded for the abort.
 * @param args.streamId The stream to abort (first form).
 * @param args.threadId The thread containing the stream (second form).
 * @param args.order The message order of the stream to abort (second form).
 * @returns The boolean returned by the underlying abort mutation.
 */
export async function abortStream(
  ctx: MutationCtx | ActionCtx,
  component: AgentComponent,
  args: { reason: string } & (
    | { streamId: string }
    | { threadId: string; order: number }
  ),
): Promise<boolean> {
  if ("streamId" in args) {
    return await ctx.runMutation(component.streams.abort, {
      reason: args.reason,
      streamId: args.streamId,
    });
  } else {
    return await ctx.runMutation(component.streams.abortByOrder, {
      reason: args.reason,
      threadId: args.threadId,
      order: args.order,
    });
  }
}

/**
 * List the streaming messages for a thread.
 * @param ctx A ctx object from a query, mutation, or action.
 * @param component The agent component, usually `components.agent`.
 * @param args.threadId The thread to list streams for.
 * @param args.startOrder The order of the messages in the thread to start listing from.
 * @param args.includeStatuses The statuses to include in the list.
 * @returns The streams for the thread.
 */
export async function listStreams(
  ctx: QueryCtx | MutationCtx | ActionCtx,
  component: AgentComponent,
  {
    threadId,
    startOrder,
    includeStatuses,
  }: {
    threadId: string;
    startOrder?: number;
    includeStatuses?: ("streaming" | "finished" | "aborted")[];
  },
): Promise<StreamMessage[]> {
  return ctx.runQuery(component.streams.list, {
    threadId,
    startOrder,
    statuses: includeStatuses,
  });
}

export type StreamingOptions = {
  /**
   * The minimum granularity of deltas to save.
   * Note: this is not a guarantee that every delta will be exactly one line.
   * E.g. if "line" is specified, it won't save any deltas until it encounters
   * a newline character.
   * Defaults to a regex matching a single punctuation or whitespace character.
   */
  chunking?: "word" | "line" | RegExp | ChunkDetector;
  /**
   * The minimum number of milliseconds to wait between saving deltas.
   * Defaults to 250.
   */
  throttleMs?: number;
  /**
   * If set to true, this will return immediately, as it would if you weren't
   * saving the deltas. Otherwise, the call will "consume" the stream with
   * .consumeStream(), which waits for the stream to finish before returning.
   *
   * When saving deltas, you're often not interacting with the stream otherwise.
   */
  returnImmediately?: boolean;
};

export const DEFAULT_STREAMING_OPTIONS = {
  // Matches a single punctuation or whitespace character.
  chunking: /[\p{P}\s]/u,
  throttleMs: 250,
  returnImmediately: false,
} satisfies StreamingOptions;

/**
 * Append a chunking transform to the transforms passed to `streamText`.
 *
 * @param options The options passed to `agent.streamText` to decide whether to
 * save deltas while streaming. If falsy, `existing` is returned untouched.
 * If `true`, or an object without `chunking`, the default chunking is used.
 * @param existing The transforms passed to `agent.streamText` to merge with.
 * This array is never mutated.
 * @returns The merged transforms to pass to the underlying `streamText` call.
 */
export function mergeTransforms<TOOLS extends ToolSet>(
  options: { chunking?: StreamingOptions["chunking"] } | boolean | undefined,
  existing:
    | StreamTextTransform<TOOLS>
    | Array<StreamTextTransform<TOOLS>>
    | undefined,
) {
  if (!options) {
    return existing;
  }
  // Fall back to the documented default when no chunking is specified, rather
  // than handing `undefined` to smoothStream.
  const chunking =
    typeof options === "boolean"
      ? DEFAULT_STREAMING_OPTIONS.chunking
      : (options.chunking ?? DEFAULT_STREAMING_OPTIONS.chunking);
  // Copy so that pushing below doesn't mutate the caller's array.
  const transforms = Array.isArray(existing)
    ? [...existing]
    : existing
      ? [existing]
      : [];
  transforms.push(smoothStream({ delayInMs: null, chunking }));
  return transforms;
}

/**
 * DeltaStreamer can be used to save a stream of "parts" by writing
 * batches of them in "deltas" to the database so clients can subscribe
 * (using the syncStreams utility and client hooks) and re-hydrate the stream.
 * You can optionally compress the parts, e.g. concatenating text deltas, to
 * optimize the data in transit.
 */
export class DeltaStreamer<T> {
  /** The ID of the stream, set once the stream has been created. */
  streamId: string | undefined;
  public readonly config: {
    throttleMs: number;
    onAsyncAbort: (reason: string) => Promise<void>;
    compress: ((parts: T[]) => T[]) | null;
  };
  // Parts buffered since the last delta was created.
  #nextParts: T[] = [];
  // Timestamp (Date.now()) at which the latest delta write was started.
  #latestWrite: number = 0;
  // The in-flight delta write, if any.
  #ongoingWrite: Promise<void> | undefined;
  // Number of (uncompressed) parts already assigned to deltas.
  #cursor: number = 0;
  // Avoid race conditions by only creating once
  #creatingStreamIdPromise: Promise<string> | undefined;
  /** Aborted once the stream is failed or aborted, locally or remotely. */
  public abortController: AbortController;

  /**
   * @param component The agent component, usually `components.agent`.
   * @param ctx A ctx object from a mutation or action.
   * @param config.throttleMs Minimum milliseconds between delta writes.
   *   Defaults to `DEFAULT_STREAMING_OPTIONS.throttleMs`.
   * @param config.onAsyncAbort Called when writing a delta fails or is
   *   rejected by the backend.
   * @param config.abortSignal When it aborts, the stream is failed with the
   *   reason "abortSignal".
   * @param config.compress Optional function to compress a batch of parts
   *   before they are written.
   * @param metadata The arguments passed to the stream creation mutation.
   */
  constructor(
    public readonly component: AgentComponent,
    public readonly ctx: MutationCtx | ActionCtx,
    config: {
      throttleMs: number | undefined;
      onAsyncAbort: (reason: string) => Promise<void>;
      abortSignal: AbortSignal | undefined;
      compress: ((parts: T[]) => T[]) | null;
    },
    public readonly metadata: {
      threadId: string;
      userId?: string;
      order: number;
      stepOrder: number;
      agentName?: string;
      model?: string;
      provider?: string;
      providerOptions?: ProviderOptions;
      format: "UIMessageChunk" | "TextStreamPart" | undefined;
    },
  ) {
    this.config = {
      throttleMs: config.throttleMs ?? DEFAULT_STREAMING_OPTIONS.throttleMs,
      onAsyncAbort: config.onAsyncAbort,
      compress: config.compress,
    };
    this.#nextParts = [];
    this.abortController = new AbortController();
    const { abortSignal } = config;
    if (abortSignal) {
      if (abortSignal.aborted) {
        // The "abort" event has already fired and won't fire again. No stream
        // exists yet, so stopping future writes is all that's needed.
        this.abortController.abort();
      } else {
        abortSignal.addEventListener(
          "abort",
          () => {
            // fail() aborts the local controller even if the stream hasn't
            // been created yet, so no parts are written after the abort.
            this.fail("abortSignal").catch((e: unknown) => {
              console.error("Failed to abort stream on abortSignal", e);
            });
          },
          { once: true },
        );
      }
    }
  }

  /**
   * Get the stream ID, creating the stream on first use. Concurrent callers
   * share a single creation mutation.
   * @returns The ID of the stream.
   */
  public async getStreamId(): Promise<string> {
    if (this.streamId) {
      return this.streamId;
    }
    if (!this.#creatingStreamIdPromise) {
      this.#creatingStreamIdPromise = this.ctx.runMutation(
        this.component.streams.create,
        this.metadata,
      );
    }
    this.streamId = await this.#creatingStreamIdPromise;
    return this.streamId;
  }

  /**
   * Buffer parts and, if no write is in flight and the throttle interval has
   * elapsed, start writing them as a delta. Does nothing once aborted.
   * @param parts The parts to add to the stream.
   */
  public async addParts(parts: T[]) {
    if (this.abortController.signal.aborted) {
      return;
    }
    await this.getStreamId();
    // The stream may have been aborted while it was being created.
    if (this.abortController.signal.aborted) {
      return;
    }
    this.#nextParts.push(...parts);
    if (
      !this.#ongoingWrite &&
      Date.now() - this.#latestWrite >= this.config.throttleMs
    ) {
      this.#ongoingWrite = this.#sendDelta();
    }
  }

  /**
   * Add every chunk of the stream as a part, then finish the stream.
   * @param stream The stream to consume.
   */
  public async consumeStream(stream: AsyncIterableStream<T>) {
    for await (const chunk of stream) {
      await this.addParts([chunk]);
    }
    await this.finish();
  }

  /**
   * Write the buffered parts as one delta. If the backend rejects the delta
   * or the write throws, `onAsyncAbort` is called and the local controller is
   * aborted. A thrown error is re-thrown.
   */
  async #sendDelta() {
    if (this.abortController.signal.aborted) {
      return;
    }
    const delta = this.#createDelta();
    if (!delta) {
      return;
    }
    this.#latestWrite = Date.now();
    try {
      const success = await this.ctx.runMutation(
        this.component.streams.addDelta,
        delta,
      );
      if (!success) {
        await this.config.onAsyncAbort("async abort");
        this.abortController.abort();
        return;
      }
    } catch (e) {
      await this.config.onAsyncAbort(
        e instanceof Error ? e.message : "unknown error",
      );
      this.abortController.abort();
      throw e;
    }
    // Now that we've sent the delta, check if we need to send another one.
    if (
      this.#nextParts.length > 0 &&
      Date.now() - this.#latestWrite >= this.config.throttleMs
    ) {
      // We send again immediately with the accumulated deltas.
      this.#ongoingWrite = this.#sendDelta();
    } else {
      this.#ongoingWrite = undefined;
    }
  }

  /**
   * Turn the buffered parts into a delta and advance the cursor.
   * The `start`/`end` range counts parts before compression.
   * @returns The delta, or `undefined` if there are no buffered parts.
   * @throws If the stream has not been created yet.
   */
  #createDelta(): StreamDelta | undefined {
    if (this.#nextParts.length === 0) {
      return undefined;
    }
    // Check the invariant before touching the cursor or the buffer, so a
    // violation doesn't leave them in an inconsistent state.
    if (!this.streamId) {
      throw new Error("Creating a delta before the stream is created");
    }
    const start = this.#cursor;
    const end = start + this.#nextParts.length;
    this.#cursor = end;
    const parts = this.config.compress
      ? this.config.compress(this.#nextParts)
      : this.#nextParts;
    this.#nextParts = [];
    return { streamId: this.streamId, start, end, parts };
  }

  /**
   * Flush any remaining parts and mark the stream as finished.
   * Does nothing if the stream was never created.
   */
  public async finish() {
    if (!this.streamId) {
      return;
    }
    await this.#ongoingWrite;
    await this.#sendDelta();
    // NOTE(review): this also runs when the streamer was aborted; confirm the
    // finish mutation tolerates an already-aborted stream.
    await this.ctx.runMutation(this.component.streams.finish, {
      streamId: this.streamId,
    });
  }

  /**
   * Stop writing parts and mark the stream as aborted. Does nothing if
   * already aborted.
   * @param reason The reason recorded for the abort.
   */
  public async fail(reason: string) {
    if (this.abortController.signal.aborted) {
      return;
    }
    this.abortController.abort();
    let streamId = this.streamId;
    if (!streamId && this.#creatingStreamIdPromise) {
      // Creation is in flight: wait for it so the new stream gets marked as
      // aborted instead of being left behind unaborted.
      try {
        streamId = await this.getStreamId();
      } catch {
        // Creation failed, so there's no stream to abort. The error surfaces
        // to whoever initiated the creation.
        return;
      }
    }
    if (!streamId) {
      return;
    }
    await this.#ongoingWrite;
    await this.ctx.runMutation(this.component.streams.abort, {
      streamId,
      reason,
    });
  }
}

/**
 * Compressing parts when streaming to save bandwidth in deltas.
 * Consecutive "text-delta" (or "reasoning-delta") chunks with the same id are
 * merged into one chunk with concatenated `delta`s.
 * @param parts The chunks to compress. Neither the array nor its elements are
 *   mutated.
 * @returns The compressed chunks.
 */
export function compressUIMessageChunks(
  parts: UIMessageChunk[],
): UIMessageChunk[] {
  const compressed: UIMessageChunk[] = [];
  for (const part of parts) {
    const last = compressed.at(-1);
    if (part.type === "text-delta" || part.type === "reasoning-delta") {
      if (last?.type === part.type && part.id === last.id) {
        // `last` is always a copy made below, so this is safe to mutate.
        last.delta += part.delta;
      } else {
        // Copy so that later merging doesn't mutate the caller's chunk.
        compressed.push({ ...part });
      }
    } else {
      compressed.push(part);
    }
  }
  return compressed;
}

/**
 * Compressing text stream parts when streaming to save bandwidth in deltas.
 * Consecutive "text-delta" (or "reasoning-delta") parts with the same id are
 * merged into one part with concatenated `text`, and "file" parts are emitted
 * without their `uint8Array` contents.
 * @param parts The parts to compress. Neither the array nor its elements are
 *   mutated.
 * @returns The compressed parts.
 */
export function compressTextStreamParts(
  parts: TextStreamPart<ToolSet>[],
): TextStreamPart<ToolSet>[] {
  const compressed: TextStreamPart<ToolSet>[] = [];
  for (const part of parts) {
    const last = compressed.at(-1);
    if (part.type === "text-delta" || part.type === "reasoning-delta") {
      if (last?.type === part.type && part.id === last.id) {
        // `last` is always a copy made below, so this is safe to mutate.
        last.text += part.text;
      } else {
        // Copy so that later merging doesn't mutate the caller's part.
        compressed.push({ ...part });
      }
    } else if (part.type === "file") {
      // Emit only the stripped copy: pushing the original as well would
      // duplicate the part and send the binary contents anyway.
      compressed.push({
        type: "file",
        file: {
          ...part.file,
          // Deliberately dropped; the cast keeps the part's declared shape.
          uint8Array: undefined as unknown as Uint8Array,
        },
      });
    } else {
      compressed.push(part);
    }
  }
  return compressed;
}