@convex-dev/agent
Version:
An agent component for Convex.
294 lines • 9.92 kB
JavaScript
import { smoothStream, } from "ai";
import { v } from "convex/values";
import { vMessageDoc, vPaginationResult, vStreamDelta, vStreamMessage, } from "../validators.js";
/**
 * Validator for a paginated page of message docs, extended with an optional
 * `streams` field. When present, `streams` is one of:
 * - `{ kind: "list", messages }` — the stream messages for the thread, or
 * - `{ kind: "deltas", deltas }` — stream deltas,
 * which are the two non-undefined shapes `syncStreams` returns.
 */
export const vStreamMessagesReturnValue = v.object({
    ...vPaginationResult(vMessageDoc).fields,
    streams: v.optional(
        v.union(
            v.object({
                kind: v.literal("list"),
                messages: v.array(vStreamMessage),
            }),
            v.object({
                kind: v.literal("deltas"),
                deltas: v.array(vStreamDelta),
            }),
        ),
    ),
});
/**
 * Fetch the streaming data the React hooks `useThreadMessages` /
 * `useStreamingThreadMessages` ask for.
 *
 * - No `streamArgs`: nothing is fetched and `undefined` is returned.
 * - `streamArgs.kind === "list"`: the thread's streams are listed via
 *   `listStreams`, starting at `streamArgs.startOrder`.
 * - Any other kind: the deltas for each stream are fetched from the
 *   per-stream `streamArgs.cursors` via `component.streams.listDeltas`.
 *
 * @param ctx A ctx object from a query, mutation, or action.
 * @param component The agent component, usually `components.agent`.
 * @param args.threadId The thread to sync streams for.
 * @param args.streamArgs The stream arguments: a list request or per-stream cursors.
 * @param args.includeStatuses The statuses to include when listing streams.
 * @returns `undefined`, `{ kind: "list", messages }`, or `{ kind: "deltas", deltas }`.
 */
export async function syncStreams(ctx, component, { threadId, streamArgs, includeStatuses, }) {
    if (!streamArgs) {
        return undefined;
    }
    if (streamArgs.kind !== "list") {
        const deltas = await ctx.runQuery(component.streams.listDeltas, {
            threadId,
            cursors: streamArgs.cursors,
        });
        return { kind: "deltas", deltas };
    }
    const messages = await listStreams(ctx, component, {
        threadId,
        startOrder: streamArgs.startOrder,
        includeStatuses,
    });
    return { kind: "list", messages };
}
/**
 * Abort a stream, identified either directly by id or by its thread and
 * message order.
 *
 * @param ctx A ctx object that can run mutations.
 * @param component The agent component, usually `components.agent`.
 * @param args Either `{ streamId, reason }` (uses `streams.abort`) or
 *   `{ threadId, order, reason }` (uses `streams.abortByOrder`). The presence
 *   of a `streamId` key decides which form is used.
 * @returns Whatever the chosen abort mutation returns.
 */
export async function abortStream(ctx, component, args) {
    const byStreamId = "streamId" in args;
    const mutation = byStreamId
        ? component.streams.abort
        : component.streams.abortByOrder;
    const mutationArgs = byStreamId
        ? { reason: args.reason, streamId: args.streamId }
        : { reason: args.reason, threadId: args.threadId, order: args.order };
    return await ctx.runMutation(mutation, mutationArgs);
}
/**
 * List the streaming messages for a thread by running the component's
 * `streams.list` query.
 *
 * @param ctx A ctx object from a query, mutation, or action.
 * @param component The agent component, usually `components.agent`.
 * @param args.threadId The thread to list streams for.
 * @param args.startOrder The message order in the thread to start listing from.
 * @param args.includeStatuses The statuses to include; forwarded as `statuses`.
 * @returns The query's result: the streams for the thread.
 */
export async function listStreams(ctx, component, { threadId, startOrder, includeStatuses, }) {
    const listQuery = component.streams.list;
    const queryArgs = { threadId, startOrder, statuses: includeStatuses };
    return ctx.runQuery(listQuery, queryArgs);
}
/**
 * Default streaming options. In this file, `chunking` is used by
 * `mergeTransforms` when streaming options are simply `true`, and
 * `throttleMs` is used by `DeltaStreamer` when no throttle is configured.
 */
export const DEFAULT_STREAMING_OPTIONS = {
    // A single-character class: matches ANY one Unicode punctuation (\p{P})
    // or whitespace character. It is not "punctuation followed by
    // whitespace", so chunks break at every word boundary and every
    // punctuation mark. They are roughly word-sized, not sentences/clauses.
    chunking: /[\p{P}\s]/u,
    // Minimum milliseconds between the starts of consecutive delta writes.
    throttleMs: 250,
    returnImmediately: false,
};
/**
 * Append a `smoothStream` transform (controls how output is chunked when
 * saving deltas) to the transforms the caller passed to `agent.streamText`.
 *
 * @param options The streaming options passed to `agent.streamText` that
 *   decide whether to save deltas while streaming. Falsy means "don't", and
 *   `existing` is returned untouched. `true` uses
 *   `DEFAULT_STREAMING_OPTIONS.chunking`. An object has its `chunking`
 *   forwarded as-is.
 * @param existing The transform(s) passed to `agent.streamText`: a single
 *   transform, an array of them, or nothing.
 * @returns The merged transforms to pass to the underlying `streamText` call:
 *   a new array (the caller's array is never mutated), or `existing` itself
 *   when `options` is falsy.
 */
export function mergeTransforms(options, existing) {
    if (!options) {
        return existing;
    }
    const chunking = typeof options === "boolean"
        ? DEFAULT_STREAMING_OPTIONS.chunking
        : options.chunking;
    // Copy instead of pushing onto the caller's array: if the same transforms
    // array is reused across calls, mutating it would stack another
    // smoothStream onto it every time.
    const transforms = Array.isArray(existing)
        ? [...existing]
        : existing
            ? [existing]
            : [];
    // delayInMs: null -> chunk the text without adding an artificial delay.
    transforms.push(smoothStream({ delayInMs: null, chunking }));
    return transforms;
}
/**
 * DeltaStreamer saves a stream of "parts" by writing batches of them as
 * "deltas" to the database, so clients can subscribe (using the syncStreams
 * utility and client hooks) and re-hydrate the stream. Parts can optionally
 * be compressed (e.g. concatenating text deltas) to optimize the data in
 * transit.
 *
 * Writes are throttled. A new `streams.addDelta` write starts only when none
 * is in flight and at least `config.throttleMs` has passed since the previous
 * write started. Parts arriving in between are buffered and sent together,
 * and `finish()` flushes whatever remains before marking the stream finished.
 */
export class DeltaStreamer {
    component;
    ctx;
    metadata;
    streamId;
    config;
    // Parts received but not yet handed to a delta.
    #nextParts = [];
    // Date.now() at the start of the most recent delta write.
    #latestWrite = 0;
    // The in-flight delta write, if any. It clears itself when it settles,
    // unless it has already been replaced by a follow-up write it chained.
    #ongoingWrite;
    // Number of raw parts handed to deltas so far (start index of the next delta).
    #cursor = 0;
    // Avoid race conditions by only creating the stream once.
    #creatingStreamIdPromise;
    abortController;
    /**
     * @param component The agent component, usually `components.agent`.
     * @param ctx A ctx object whose `runMutation` performs all writes.
     * @param config `throttleMs` (defaults to DEFAULT_STREAMING_OPTIONS.throttleMs),
     *   `onAsyncAbort(reason)` (awaited when a delta write fails or is
     *   rejected), optional `compress(parts)`, and optional `abortSignal`.
     * @param metadata Arguments passed to `streams.create` when the stream is created.
     */
    constructor(component, ctx, config, metadata) {
        this.component = component;
        this.ctx = ctx;
        this.metadata = metadata;
        this.config = {
            throttleMs: config.throttleMs ?? DEFAULT_STREAMING_OPTIONS.throttleMs,
            onAsyncAbort: config.onAsyncAbort,
            compress: config.compress,
        };
        this.abortController = new AbortController();
        if (config.abortSignal) {
            // NOTE(review): an abort that fires before the stream has been
            // created (no streamId yet) is ignored here, and a signal that is
            // already aborted when this constructor runs will not fire
            // "abort" again. Confirm callers are fine with both.
            config.abortSignal.addEventListener("abort", async () => {
                if (this.abortController.signal.aborted) {
                    return;
                }
                if (this.streamId) {
                    this.abortController.abort();
                    // Let the in-flight write settle before marking the stream aborted.
                    await this.#ongoingWrite;
                    await this.ctx.runMutation(this.component.streams.abort, {
                        streamId: this.streamId,
                        reason: "abortSignal",
                    });
                }
            });
        }
    }
    /**
     * Return the stream id, creating the stream (via `streams.create` with
     * `metadata`) on first use. Concurrent callers share one creation.
     * @returns The stream id.
     */
    async getStreamId() {
        if (this.streamId) {
            return this.streamId;
        }
        if (this.#creatingStreamIdPromise) {
            return this.#creatingStreamIdPromise;
        }
        this.#creatingStreamIdPromise = this.ctx.runMutation(this.component.streams.create, this.metadata);
        this.streamId = await this.#creatingStreamIdPromise;
        // The creating caller must get the id too, not `undefined`.
        return this.streamId;
    }
    /**
     * Buffer parts and start a delta write if none is in flight and the
     * throttle window has passed. Does nothing once aborted.
     * @param parts The parts to append to the stream.
     */
    async addParts(parts) {
        if (this.abortController.signal.aborted) {
            return;
        }
        await this.getStreamId();
        this.#nextParts.push(...parts);
        if (!this.#ongoingWrite &&
            Date.now() - this.#latestWrite >= this.config.throttleMs) {
            this.#scheduleWrite();
        }
    }
    /**
     * Add every chunk of `stream` as a part, then finish the stream.
     * @param stream An async iterable of parts.
     */
    async consumeStream(stream) {
        for await (const chunk of stream) {
            await this.addParts([chunk]);
        }
        await this.finish();
    }
    // Start a delta write and record it as the in-flight write. The write
    // clears `#ongoingWrite` when it settles, unless it has already been
    // replaced by a follow-up write it chained. This holds even when
    // `#sendDelta` returns early without writing, so a no-op write can
    // never block later ones.
    #scheduleWrite() {
        const write = this.#sendDelta().finally(() => {
            if (this.#ongoingWrite === write) {
                this.#ongoingWrite = undefined;
            }
        });
        this.#ongoingWrite = write;
    }
    // Wait until no write is in flight, including follow-up writes that the
    // awaited write chained.
    async #waitForWrites() {
        while (this.#ongoingWrite) {
            await this.#ongoingWrite;
        }
    }
    // Write all buffered parts as one delta. On a rejected write
    // (`addDelta` returns falsy) or a thrown error, awaits `onAsyncAbort`
    // and aborts. A thrown error is rethrown.
    async #sendDelta() {
        if (this.abortController.signal.aborted) {
            return;
        }
        const delta = this.#createDelta();
        if (!delta) {
            return;
        }
        this.#latestWrite = Date.now();
        try {
            const success = await this.ctx.runMutation(this.component.streams.addDelta, delta);
            if (!success) {
                await this.config.onAsyncAbort("async abort");
                this.abortController.abort();
                return;
            }
        }
        catch (e) {
            await this.config.onAsyncAbort(e instanceof Error ? e.message : "unknown error");
            this.abortController.abort();
            throw e;
        }
        // Parts that arrived during this write are sent right away if the
        // throttle window has already passed. Otherwise they wait for the next
        // addParts() or finish().
        if (this.#nextParts.length > 0 &&
            Date.now() - this.#latestWrite >= this.config.throttleMs) {
            this.#scheduleWrite();
        }
    }
    // Take all buffered parts as a delta covering raw-part indices
    // [start, end), optionally compressed. Returns undefined if none are buffered.
    #createDelta() {
        if (this.#nextParts.length === 0) {
            return undefined;
        }
        // Check before touching the cursor/buffer so a throw leaves state intact.
        if (!this.streamId) {
            throw new Error("Creating a delta before the stream is created");
        }
        const start = this.#cursor;
        const end = start + this.#nextParts.length;
        this.#cursor = end;
        const parts = this.config.compress
            ? this.config.compress(this.#nextParts)
            : this.#nextParts;
        this.#nextParts = [];
        return { streamId: this.streamId, start, end, parts };
    }
    /**
     * Flush any buffered parts, then mark the stream finished. Every delta
     * write (including chained follow-ups) completes before `streams.finish`
     * runs. Does nothing if the stream was never created.
     */
    async finish() {
        if (!this.streamId) {
            return;
        }
        await this.#waitForWrites();
        // Flush the remainder (a no-op if nothing is buffered or we aborted).
        this.#scheduleWrite();
        await this.#waitForWrites();
        await this.ctx.runMutation(this.component.streams.finish, {
            streamId: this.streamId,
        });
    }
    /**
     * Abort the stream with `reason`, after the in-flight write settles.
     * Does nothing if already aborted or if the stream was never created.
     * @param reason The abort reason recorded via `streams.abort`.
     */
    async fail(reason) {
        if (this.abortController.signal.aborted) {
            return;
        }
        this.abortController.abort();
        if (!this.streamId) {
            return;
        }
        await this.#ongoingWrite;
        await this.ctx.runMutation(this.component.streams.abort, {
            streamId: this.streamId,
            reason,
        });
    }
}
/**
 * Compress UI message chunks to save bandwidth in deltas. Runs of
 * consecutive `text-delta` chunks (or `reasoning-delta` chunks) sharing the
 * same `id` are merged into one chunk whose `delta` is their concatenation.
 * All other chunks pass through unchanged and in order.
 *
 * The input chunks are never mutated; merged chunks are fresh copies.
 *
 * @param parts The chunks to compress.
 * @returns The compressed chunks.
 */
export function compressUIMessageChunks(parts) {
    const compressed = [];
    for (const part of parts) {
        const last = compressed.at(-1);
        if (part.type === "text-delta" || part.type === "reasoning-delta") {
            if (last?.type === part.type && part.id === last.id) {
                // `last` is always our own copy (see below), never a caller's chunk.
                last.delta += part.delta;
            }
            else {
                compressed.push({ ...part });
            }
        }
        else {
            compressed.push(part);
        }
    }
    return compressed;
}
/**
 * Compress text stream parts to save bandwidth in deltas.
 * - Runs of consecutive `text-delta` parts (or `reasoning-delta` parts)
 *   sharing the same `id` are merged into one part whose `text` is their
 *   concatenation.
 * - `file` parts are replaced by a copy whose `file.uint8Array` is set to
 *   `undefined`, so the raw bytes are not sent. Each file appears exactly once.
 * - All other parts pass through unchanged and in order.
 *
 * The input parts are never mutated.
 *
 * NOTE(review): `{ ...part.file }` copies only own enumerable properties. If
 * `file` is a class instance exposing data through prototype getters, confirm
 * the copy still carries the intended fields.
 *
 * @param parts The stream parts to compress.
 * @returns The compressed parts.
 */
export function compressTextStreamParts(parts) {
    const compressed = [];
    for (const part of parts) {
        const last = compressed.at(-1);
        if (part.type === "text-delta" || part.type === "reasoning-delta") {
            if (last?.type === part.type && part.id === last.id) {
                // `last` is always our own copy (see below), never a caller's part.
                last.text += part.text;
            }
            else {
                compressed.push({ ...part });
            }
        }
        else if (part.type === "file") {
            // Push ONLY the stripped copy. Also pushing `part` would duplicate
            // the file and re-send the bytes we meant to drop.
            compressed.push({
                type: "file",
                file: {
                    ...part.file,
                    uint8Array: undefined,
                },
            });
        }
        else {
            compressed.push(part);
        }
    }
    return compressed;
}
//# sourceMappingURL=streaming.js.map