@convex-dev/agent
Version:
An agent component for Convex.
311 lines (298 loc) • 8.78 kB
text/typescript
import {
type ChunkDetector,
smoothStream,
type StreamTextTransform,
type TextStreamPart,
type ToolSet,
} from "ai";
import type {
ProviderOptions,
StreamArgs,
StreamDelta,
StreamMessage,
} from "../validators.js";
import type {
AgentComponent,
RunActionCtx,
RunMutationCtx,
RunQueryCtx,
SyncStreamsReturnValue,
} from "./types.js";
import { omit } from "convex-helpers";
import { serializeTextStreamingPartsV5 } from "../parts.js";
/**
 * Fetches stream data for a thread, for use with the React hooks
 * `useThreadMessages` or `useStreamingThreadMessages`.
 *
 * Depending on `args.streamArgs.kind`, this either lists the thread's streams
 * (`"list"`) or fetches the deltas following the supplied per-stream cursors
 * (any other kind).
 *
 * @param ctx A ctx object from a query, mutation, or action.
 * @param component The agent component, usually `components.agent`.
 * @param args.threadId The thread to sync streams for.
 * @param args.streamArgs The stream arguments with per-stream cursors.
 * @param args.includeStatuses Statuses to include when listing streams.
 * @returns `undefined` when no `streamArgs` are given; otherwise either
 * `{ kind: "list", messages }` or `{ kind: "deltas", deltas }`.
 */
export async function syncStreams(
  ctx: RunQueryCtx,
  component: AgentComponent,
  args: {
    threadId: string;
    streamArgs: StreamArgs | undefined;
    // By default, only streaming messages are included.
    includeStatuses?: ("streaming" | "finished" | "aborted")[];
  },
): Promise<SyncStreamsReturnValue | undefined> {
  const { threadId, streamArgs, includeStatuses } = args;
  if (!streamArgs) {
    return undefined;
  }

  if (streamArgs.kind !== "list") {
    // Delta mode: resume each stream from the cursor the client already has.
    const deltas = await ctx.runQuery(component.streams.listDeltas, {
      threadId,
      cursors: streamArgs.cursors,
    });
    return { kind: "deltas", deltas };
  }

  // List mode: enumerate the thread's streams from the requested order.
  const messages = await listStreams(ctx, component, {
    threadId,
    startOrder: streamArgs.startOrder,
    includeStatuses,
  });
  return { kind: "list", messages };
}
/**
 * Aborts a stream, identified either directly by its stream id or by the
 * thread and message order it belongs to.
 *
 * @param ctx A ctx object that can run mutations.
 * @param component The agent component, usually `components.agent`.
 * @param args.reason The reason recorded for the abort.
 * @param args.streamId The stream to abort (first form).
 * @param args.threadId The thread containing the stream (second form).
 * @param args.order The message order of the stream in the thread (second form).
 * @returns The boolean returned by the underlying abort mutation.
 */
export async function abortStream(
  ctx: RunMutationCtx,
  component: AgentComponent,
  args: { reason: string } & (
    | { streamId: string }
    | { threadId: string; order: number }
  ),
): Promise<boolean> {
  const { reason } = args;

  if (!("streamId" in args)) {
    // No stream id given: look the stream up by its position in the thread.
    const { threadId, order } = args;
    return await ctx.runMutation(component.streams.abortByOrder, {
      reason,
      threadId,
      order,
    });
  }

  return await ctx.runMutation(component.streams.abort, {
    reason,
    streamId: args.streamId,
  });
}
/**
 * Lists the streaming messages of a thread by querying the component's
 * `streams.list` function.
 *
 * @param ctx A ctx object from a query, mutation, or action.
 * @param component The agent component, usually `components.agent`.
 * @param args.threadId The thread to list streams for.
 * @param args.startOrder The order of the messages in the thread to start listing from.
 * @param args.includeStatuses The statuses to include in the list
 * (forwarded to the query as `statuses`).
 * @returns The streams for the thread.
 */
export async function listStreams(
  ctx: RunQueryCtx,
  component: AgentComponent,
  args: {
    threadId: string;
    startOrder?: number;
    includeStatuses?: ("streaming" | "finished" | "aborted")[];
  },
): Promise<StreamMessage[]> {
  const query = {
    threadId: args.threadId,
    startOrder: args.startOrder,
    // The component names this filter `statuses`.
    statuses: args.includeStatuses,
  };
  return ctx.runQuery(component.streams.list, query);
}
/**
 * Options controlling how streamed text is chunked and saved as deltas.
 * Defaults for every field are defined in `DEFAULT_STREAMING_OPTIONS`.
 */
export type StreamingOptions = {
/**
* The minimum granularity of deltas to save.
* Note: this is not a guarantee that every delta will be exactly one line.
* E.g. if "line" is specified, it won't save any deltas until it encounters
* a newline character.
* Defaults to `DEFAULT_STREAMING_OPTIONS.chunking`, a regex that matches any
* single punctuation or whitespace character.
*/
chunking?: "word" | "line" | RegExp | ChunkDetector;
/**
* The minimum number of milliseconds to wait between saving deltas.
* Defaults to 250.
*/
throttleMs?: number;
/**
* If set to true, this will return immediately, as it would if you weren't
* saving the deltas. Otherwise, the call will "consume" the stream with
* .consumeStream(), which waits for the stream to finish before returning.
*
* When saving deltas, you're often not interacting with the stream otherwise.
* Defaults to false.
*/
returnImmediately?: boolean;
};
/**
 * Default values for every `StreamingOptions` field. Used when streaming
 * options are given as `true`, and as the base that caller-supplied options
 * are spread over in `DeltaStreamer`.
 */
export const DEFAULT_STREAMING_OPTIONS = {
// Matches any single Unicode punctuation (\p{P}) or whitespace (\s)
// character. Note this is a character class, so punctuation does not need to
// be followed by whitespace to match.
chunking: /[\p{P}\s]/u,
throttleMs: 250,
returnImmediately: false,
} satisfies StreamingOptions;
/**
 * Appends a `smoothStream` chunking transform to any transforms the caller
 * already configured.
 *
 * @param options The streaming options. When falsy, `existing` is returned
 * untouched. When `true`, `DEFAULT_STREAMING_OPTIONS.chunking` is used;
 * otherwise `options.chunking` is passed through as-is (possibly undefined).
 * @param existing A transform, a list of transforms, or nothing.
 * @returns `existing` when `options` is falsy; otherwise a NEW array holding
 * the existing transforms followed by the `smoothStream` transform. The
 * caller's array is never modified.
 */
export function mergeTransforms<TOOLS extends ToolSet>(
  options: StreamingOptions | boolean | undefined,
  existing:
    | StreamTextTransform<TOOLS>
    | Array<StreamTextTransform<TOOLS>>
    | undefined,
) {
  if (!options) {
    return existing;
  }
  const chunking =
    typeof options === "boolean"
      ? DEFAULT_STREAMING_OPTIONS.chunking
      : options.chunking;
  // Copy rather than push onto the caller's array: the same transform list is
  // often reused across calls, and mutating it would stack one extra
  // smoothing transform per call.
  const transforms: Array<StreamTextTransform<TOOLS>> = Array.isArray(existing)
    ? [...existing]
    : existing
      ? [existing]
      : [];
  transforms.push(smoothStream({ delayInMs: null, chunking }));
  return transforms;
}
/**
 * Buffers text stream parts and saves them as throttled deltas through the
 * agent component's `streams` functions.
 *
 * Typical use: call `addParts` as parts arrive (the first call creates the
 * stream), then `finish` on success or `fail` on error. If
 * `metadata.abortSignal` aborts, the stream is aborted with the reason
 * "abortSignal".
 */
export class DeltaStreamer {
  /** Id of the saved stream; undefined until `addParts` has created it. */
  public streamId: string | undefined;
  /** The streaming options with defaults filled in. */
  public readonly options: Required<StreamingOptions>;
  /** Parts received but not yet packaged into a delta. */
  #nextParts: TextStreamPart<ToolSet>[] = [];
  /** `Date.now()` when the most recent delta write started (for throttling). */
  #latestWrite: number = 0;
  /** The in-flight delta write, if any. At most one write runs at a time. */
  #ongoingWrite: Promise<void> | undefined;
  /** Count of parts already packaged into deltas; the next delta's `start`. */
  #cursor: number = 0;
  /** The `streams.create` call, shared so concurrent callers create one stream. */
  #creatingStream: Promise<string> | undefined;
  /**
   * Aborted when the external abort signal fires, when `fail` is called, or
   * when saving a delta throws or reports failure. Once aborted, `addParts`
   * is a no-op.
   */
  public abortController: AbortController;

  /**
   * @param component The agent component, usually `components.agent`.
   * @param ctx A ctx object that can run mutations.
   * @param options `true` for the defaults, or overrides that are spread over
   * `DEFAULT_STREAMING_OPTIONS`.
   * @param metadata Passed to `streams.create` (minus `abortSignal`) when the
   * stream is created. `abortSignal`, if given, aborts the stream.
   */
  constructor(
    public readonly component: AgentComponent,
    public readonly ctx: RunActionCtx,
    options: true | StreamingOptions,
    public readonly metadata: {
      threadId: string;
      userId?: string;
      order: number;
      stepOrder: number;
      agentName?: string;
      model?: string;
      provider?: string;
      providerOptions?: ProviderOptions;
      abortSignal?: AbortSignal;
    },
  ) {
    this.options =
      typeof options === "boolean"
        ? DEFAULT_STREAMING_OPTIONS
        : { ...DEFAULT_STREAMING_OPTIONS, ...options };
    this.abortController = new AbortController();

    const { abortSignal } = metadata;
    if (!abortSignal) {
      return;
    }
    if (abortSignal.aborted) {
      // An already-aborted signal never dispatches "abort" again, so honor it
      // now. No stream exists yet, so there is nothing to abort remotely.
      this.abortController.abort();
      return;
    }
    abortSignal.addEventListener(
      "abort",
      () => {
        if (this.abortController.signal.aborted) {
          // Already aborted via fail() or a failed write.
          return;
        }
        // The event target ignores the listener's return value, so a rejection
        // here could never be observed by a caller. Log it instead of leaving
        // an unhandled rejection.
        this.#abortStream("abortSignal").catch((e: unknown) => {
          console.error("Failed to abort stream after abort signal", e);
        });
      },
      { once: true },
    );
  }

  /**
   * Buffers `parts` and, if no write is in flight and at least
   * `options.throttleMs` has passed since the last write started, begins
   * saving the buffered parts as a delta. Creates the stream on first use.
   * Does nothing once aborted.
   */
  public async addParts(parts: TextStreamPart<ToolSet>[]): Promise<void> {
    if (this.abortController.signal.aborted) {
      return;
    }
    if (!this.streamId) {
      await this.#ensureStream();
      if (this.abortController.signal.aborted) {
        // Aborted while the stream was being created; drop the parts.
        return;
      }
    }
    this.#nextParts.push(...parts);
    if (
      // Never start a write with nothing to send: it would return early and
      // leave a settled promise in #ongoingWrite, blocking all later writes.
      this.#nextParts.length > 0 &&
      !this.#ongoingWrite &&
      Date.now() - this.#latestWrite >= this.options.throttleMs
    ) {
      this.#ongoingWrite = this.#sendDelta();
    }
  }

  /**
   * Creates the stream if it doesn't exist yet. Concurrent callers share a
   * single `streams.create` call so only one stream is ever created. A failed
   * creation is forgotten so that a later call can retry.
   */
  async #ensureStream(): Promise<void> {
    if (this.streamId) {
      return;
    }
    if (!this.#creatingStream) {
      this.#creatingStream = this.ctx.runMutation(
        this.component.streams.create,
        omit(this.metadata, ["abortSignal"]),
      );
    }
    const pending = this.#creatingStream;
    try {
      this.streamId = await pending;
    } catch (e) {
      if (this.#creatingStream === pending) {
        this.#creatingStream = undefined;
      }
      throw e;
    }
  }

  /**
   * Saves the buffered parts as one delta, then chains another write if more
   * parts arrived meanwhile and the throttle window has already elapsed.
   * Aborts the controller if the mutation throws (and rethrows) or returns a
   * falsy result.
   */
  async #sendDelta(): Promise<void> {
    if (this.abortController.signal.aborted) {
      return;
    }
    const delta = this.#createDelta();
    if (!delta) {
      return;
    }
    this.#latestWrite = Date.now();
    try {
      const success = await this.ctx.runMutation(
        this.component.streams.addDelta,
        delta,
      );
      if (!success) {
        this.abortController.abort();
      }
    } catch (e) {
      // #ongoingWrite intentionally stays rejected so that finish()/fail()
      // surface this error when they await it.
      this.abortController.abort();
      throw e;
    }
    // Now that we've sent the delta, check if we need to send another one.
    if (
      !this.abortController.signal.aborted &&
      this.#nextParts.length > 0 &&
      Date.now() - this.#latestWrite >= this.options.throttleMs
    ) {
      // We send again immediately with the accumulated deltas.
      this.#ongoingWrite = this.#sendDelta();
    } else {
      this.#ongoingWrite = undefined;
    }
  }

  /**
   * Packages all buffered parts into a delta covering `[start, end)` in part
   * indices, advancing the cursor and clearing the buffer.
   * @returns The delta, or undefined if nothing is buffered.
   * @throws If parts are buffered but the stream has not been created.
   */
  #createDelta(): StreamDelta | undefined {
    if (this.#nextParts.length === 0) {
      return undefined;
    }
    // Validate before touching any state so a failure doesn't lose the buffer.
    if (!this.streamId) {
      throw new Error("Creating a delta before the stream is created");
    }
    const start = this.#cursor;
    const end = start + this.#nextParts.length;
    const parts = serializeTextStreamingPartsV5(this.#nextParts);
    this.#cursor = end;
    this.#nextParts = [];
    return { streamId: this.streamId, start, end, parts };
  }

  /**
   * Marks the controller aborted and, if a stream exists (or is being
   * created), waits for any in-flight write and calls `streams.abort` with
   * the remaining buffered parts as the final delta.
   */
  async #abortStream(reason: string): Promise<void> {
    this.abortController.abort();
    let streamId = this.streamId;
    if (!streamId && this.#creatingStream) {
      try {
        // Creation is in flight: wait for it so the new stream doesn't stay
        // open with nobody left to abort it.
        streamId = await this.#creatingStream;
        this.streamId = streamId;
      } catch {
        // Creation failed (the addParts caller receives that error), so there
        // is no stream to abort.
        return;
      }
    }
    if (!streamId) {
      return;
    }
    const finalDelta = this.#createDelta();
    await this.#ongoingWrite;
    await this.ctx.runMutation(this.component.streams.abort, {
      streamId,
      reason,
      finalDelta,
    });
  }

  /**
   * Finishes the stream: waits for any in-flight write, then calls
   * `streams.finish` with the remaining buffered parts as the final delta.
   * Does nothing if the stream was never created.
   */
  public async finish(): Promise<void> {
    if (!this.streamId) {
      return;
    }
    const finalDelta = this.#createDelta();
    await this.#ongoingWrite;
    await this.ctx.runMutation(this.component.streams.finish, {
      streamId: this.streamId,
      finalDelta,
    });
  }

  /**
   * Aborts the stream with the given reason. Does nothing if the streamer was
   * already aborted.
   */
  public async fail(reason: string): Promise<void> {
    if (this.abortController.signal.aborted) {
      return;
    }
    await this.#abortStream(reason);
  }
}