@convex-dev/agent
Version:
A agent component for Convex.
413 lines • 15.4 kB
JavaScript
import { omit, pick } from "convex-helpers";
import { v } from "convex/values";
import { vStreamDelta, vStreamMessage, } from "../validators.js";
import { api, internal } from "./_generated/api.js";
import { internalMutation, mutation, query, action, } from "./_generated/server.js";
import schema from "./schema.js";
import { stream } from "convex-helpers/server/stream";
import { mergedStream } from "convex-helpers/server/stream";
import { paginator } from "convex-helpers/server/pagination";
import { deriveUIMessagesFromDeltas } from "../deltas.js";
import { fromUIMessages } from "../UIMessages.js";
const SECOND = 1000;
const MINUTE = 60 * SECOND;
const MAX_DELTAS_PER_REQUEST = 1000;
const MAX_DELTAS_PER_STREAM = 100;
const TIMEOUT_INTERVAL = 10 * MINUTE;
const DELETE_STREAM_DELAY = MINUTE * 5; // 5 minutes
const deltaValidator = schema.tables.streamDeltas.validator;
export const addDelta = mutation({
args: deltaValidator,
returns: v.boolean(),
handler: async (ctx, args) => {
const stream = await ctx.db.get(args.streamId);
if (!stream) {
console.warn("Stream not found", args.streamId);
return false;
}
if (stream.state.kind !== "streaming") {
return false;
}
await ctx.db.insert("streamDeltas", args);
await heartbeatStream(ctx, { streamId: args.streamId });
return true;
},
});
export const listDeltas = query({
args: {
threadId: v.id("threads"),
cursors: v.array(v.object({ streamId: v.id("streamingMessages"), cursor: v.number() })),
},
returns: v.array(vStreamDelta),
handler: async (ctx, args) => {
let totalDeltas = 0;
const deltas = [];
for (const cursor of args.cursors) {
const streamDeltas = await ctx.db
.query("streamDeltas")
.withIndex("streamId_start_end", (q) => q.eq("streamId", cursor.streamId).gte("start", cursor.cursor))
.take(Math.min(MAX_DELTAS_PER_STREAM, MAX_DELTAS_PER_REQUEST - totalDeltas));
totalDeltas += streamDeltas.length;
deltas.push(...streamDeltas.map((d) => pick(d, ["streamId", "start", "end", "parts"])));
if (totalDeltas >= MAX_DELTAS_PER_REQUEST) {
break;
}
}
return deltas;
},
});
export const create = mutation({
args: omit(schema.tables.streamingMessages.validator.fields, ["state"]),
returns: v.id("streamingMessages"),
handler: async (ctx, args) => {
const state = { kind: "streaming", lastHeartbeat: Date.now() };
// TODO: enforce order/stepOrder uniqueness?
const streamId = await ctx.db.insert("streamingMessages", {
...args,
state,
});
const timeoutFnId = await ctx.scheduler.runAfter(TIMEOUT_INTERVAL, internal.streams.timeoutStream, { streamId });
await ctx.db.patch(streamId, { state: { ...state, timeoutFnId } });
return streamId;
},
});
export const list = query({
args: {
threadId: v.id("threads"),
startOrder: v.optional(v.number()),
statuses: v.optional(v.array(v.union(v.literal("streaming"), v.literal("finished"), v.literal("aborted")))),
},
returns: v.array(vStreamMessage),
handler: async (ctx, args) => {
const statuses = args.statuses ?? ["streaming"];
const messages = await mergedStream(statuses.map((status) => stream(ctx.db, schema)
.query("streamingMessages")
.withIndex("threadId_state_order_stepOrder", (q) => q
.eq("threadId", args.threadId)
.eq("state.kind", status)
.gte("order", args.startOrder ?? 0))
.order("desc")), ["order", "stepOrder"]).take(100);
return messages.map((m) => publicStreamMessage(m));
},
});
function publicStreamMessage(m) {
return {
streamId: m._id,
status: m.state.kind,
...pick(m, [
"format",
"order",
"stepOrder",
"userId",
"agentName",
"model",
"provider",
"providerOptions",
]),
};
}
export const abortByOrder = mutation({
args: { threadId: v.id("threads"), order: v.number(), reason: v.string() },
returns: v.boolean(),
handler: async (ctx, args) => {
const streams = await ctx.db
.query("streamingMessages")
.withIndex("threadId_state_order_stepOrder", (q) => q
.eq("threadId", args.threadId)
.eq("state.kind", "streaming")
.eq("order", args.order))
.take(100);
for (const stream of streams) {
await abortById(ctx, { streamId: stream._id, reason: args.reason });
}
return streams.length > 0;
},
});
export const abort = mutation({
args: {
streamId: v.id("streamingMessages"),
reason: v.string(),
finalDelta: v.optional(deltaValidator),
},
returns: v.boolean(),
handler: abortById,
});
async function abortById(ctx, args) {
const stream = await ctx.db.get(args.streamId);
if (!stream) {
throw new Error(`Stream not found: ${args.streamId}`);
}
if (args.finalDelta) {
await ctx.db.insert("streamDeltas", args.finalDelta);
}
if (stream.state.kind !== "streaming") {
return false;
}
await cleanupTimeoutFn(ctx, stream);
await ctx.db.patch(args.streamId, {
state: { kind: "aborted", reason: args.reason },
});
return true;
}
async function cleanupTimeoutFn(ctx, stream) {
if (stream.state.kind === "streaming" && stream.state.timeoutFnId) {
const timeoutFn = await ctx.db.system.get(stream.state.timeoutFnId);
if (timeoutFn?.state.kind === "pending") {
await ctx.scheduler.cancel(stream.state.timeoutFnId);
}
}
}
// No longer used from the DeltaStreamer
export const finish = mutation({
args: {
streamId: v.id("streamingMessages"),
finalDelta: v.optional(deltaValidator),
},
returns: v.null(),
handler: finishHandler,
});
export async function finishHandler(ctx, args) {
if (args.finalDelta) {
await ctx.db.insert("streamDeltas", args.finalDelta);
}
const stream = await ctx.db.get(args.streamId);
if (!stream) {
throw new Error(`Stream not found: ${args.streamId}`);
}
if (stream.state.kind !== "streaming") {
console.warn(`Stream trying to finish ${args.streamId} but is ${stream.state.kind}`);
return;
}
await cleanupTimeoutFn(ctx, stream);
const cleanupFnId = await ctx.scheduler.runAfter(DELETE_STREAM_DELAY, api.streams.deleteStreamAsync, { streamId: args.streamId });
await ctx.db.patch(args.streamId, {
state: { kind: "finished", endedAt: Date.now(), cleanupFnId },
});
}
// TODO: use this heartbeat while streaming, every 30 seconds or so,
// then reduce the timeout to 60 seconds.
export const heartbeat = mutation({
args: { streamId: v.id("streamingMessages") },
returns: v.null(),
handler: heartbeatStream,
});
async function heartbeatStream(ctx, args) {
const stream = await ctx.db.get(args.streamId);
if (!stream) {
console.warn("Stream not found", args.streamId);
return;
}
if (stream.state.kind !== "streaming") {
return;
}
if (Date.now() - stream.state.lastHeartbeat < TIMEOUT_INTERVAL / 4) {
// Debounce heartbeating.
return;
}
if (!stream.state.timeoutFnId) {
throw new Error("Stream has no timeout function");
}
const timeoutFn = await ctx.db.system.get(stream.state.timeoutFnId);
if (!timeoutFn) {
throw new Error("Timeout function not found");
}
if (timeoutFn.state.kind !== "pending") {
throw new Error("Timeout function is not pending");
}
await ctx.scheduler.cancel(stream.state.timeoutFnId);
const timeoutFnId = await ctx.scheduler.runAfter(TIMEOUT_INTERVAL, internal.streams.timeoutStream, { streamId: args.streamId });
await ctx.db.patch(args.streamId, {
state: { kind: "streaming", lastHeartbeat: Date.now(), timeoutFnId },
});
}
export const timeoutStream = internalMutation({
args: { streamId: v.id("streamingMessages") },
returns: v.null(),
handler: async (ctx, args) => {
const stream = await ctx.db.get(args.streamId);
if (!stream || stream.state.kind !== "streaming") {
console.warn("Stream not found", args.streamId);
return;
}
await ctx.db.patch(args.streamId, {
state: { kind: "aborted", reason: "timeout" },
});
},
});
async function deletePageForStreamId(ctx, args) {
const deltas = await paginator(ctx.db, schema)
.query("streamDeltas")
.withIndex("streamId_start_end", (q) => q.eq("streamId", args.streamId))
.paginate({
numItems: MAX_DELTAS_PER_REQUEST,
cursor: args.cursor ?? null,
});
await Promise.all(deltas.page.map((d) => ctx.db.delete(d._id)));
if (deltas.isDone) {
const stream = await ctx.db.get(args.streamId);
if (stream) {
await cleanupTimeoutFn(ctx, stream);
if (stream.state.kind === "finished" && stream.state.cleanupFnId) {
await ctx.scheduler.cancel(stream.state.cleanupFnId);
}
await ctx.db.delete(args.streamId);
}
}
return deltas;
}
export async function deleteStreamsPageForThreadId(ctx, args) {
const allStreamMessages = schema.tables.streamingMessages.validator.fields.state.members
.flatMap((state) => state.fields.kind.value)
.map((stateKind) => stream(ctx.db, schema)
.query("streamingMessages")
.withIndex("threadId_state_order_stepOrder", (q) => q
.eq("threadId", args.threadId)
.eq("state.kind", stateKind)
.gte("order", args.streamOrder ?? 0)));
let deltaCursor = args.deltaCursor;
const streamMessage = await mergedStream(allStreamMessages, [
"threadId",
"state.kind",
"order",
"stepOrder",
]).first();
if (!streamMessage) {
return { isDone: true, streamOrder: undefined, deltaCursor: undefined };
}
const result = await deletePageForStreamId(ctx, {
streamId: streamMessage._id,
cursor: deltaCursor,
});
if (result.isDone) {
deltaCursor = undefined;
}
return { isDone: false, streamOrder: streamMessage.order, deltaCursor };
}
export const deleteStreamsPageForThreadIdMutation = internalMutation({
args: {
threadId: v.id("threads"),
streamOrder: v.optional(v.number()),
deltaCursor: v.optional(v.string()),
},
returns: v.object({
isDone: v.boolean(),
streamOrder: v.optional(v.number()),
deltaCursor: v.optional(v.string()),
}),
handler: deleteStreamsPageForThreadId,
});
export const deleteAllStreamsForThreadIdAsync = mutation({
args: {
threadId: v.id("threads"),
streamOrder: v.optional(v.number()),
deltaCursor: v.optional(v.string()),
},
returns: v.object({
isDone: v.boolean(),
streamOrder: v.optional(v.number()),
deltaCursor: v.optional(v.string()),
}),
handler: async (ctx, args) => {
const result = await deleteStreamsPageForThreadId(ctx, args);
if (!result.isDone) {
await ctx.scheduler.runAfter(0, api.streams.deleteAllStreamsForThreadIdAsync, {
threadId: args.threadId,
streamOrder: result.streamOrder,
deltaCursor: result.deltaCursor,
});
}
else {
await ctx.db.delete(args.threadId);
}
return result;
},
});
export const deleteStreamSync = mutation({
args: { streamId: v.id("streamingMessages") },
returns: v.null(),
handler: async (ctx, args) => {
let deltas = await deletePageForStreamId(ctx, args);
while (!deltas.isDone) {
deltas = await deletePageForStreamId(ctx, {
...args,
cursor: deltas.continueCursor,
});
}
},
});
export const deleteStreamAsync = mutation({
args: { streamId: v.id("streamingMessages"), cursor: v.optional(v.string()) },
returns: v.null(),
handler: async (ctx, args) => {
const result = await deletePageForStreamId(ctx, args);
if (!result.isDone) {
await ctx.scheduler.runAfter(0, api.streams.deleteStreamAsync, {
streamId: args.streamId,
cursor: result.continueCursor,
});
}
},
});
export const deleteAllStreamsForThreadIdSync = action({
args: { threadId: v.id("threads") },
returns: v.null(),
handler: async (ctx, args) => {
let result = await ctx.runMutation(internal.streams.deleteStreamsPageForThreadIdMutation, args);
while (!result.isDone) {
result = await ctx.runMutation(internal.streams.deleteStreamsPageForThreadIdMutation, {
...args,
streamOrder: result.streamOrder,
deltaCursor: result.deltaCursor,
});
}
},
});
export async function getStreamingMessages(ctx, threadId, order, stepOrder) {
return mergedStream(["aborted", "streaming", "finished"].map((state) => stream(ctx.db, schema)
.query("streamingMessages")
.withIndex("threadId_state_order_stepOrder", (q) => q
.eq("threadId", threadId)
.eq("state.kind", state)
.eq("order", order)
.lte("stepOrder", stepOrder))
.order("desc")), ["stepOrder"]).take(10);
}
export async function getStreamingMessagesWithMetadata(ctx, { threadId, order, stepOrder, }, metadata) {
// See if there are any streaming messages for this order
const streamingMessages = await getStreamingMessages(ctx, threadId, order, stepOrder);
const messages = (await Promise.all(streamingMessages.map(async (streamingMessage) => {
const deltas = await ctx.db
.query("streamDeltas")
.withIndex("streamId_start_end", (q) => q.eq("streamId", streamingMessage._id))
.take(1000);
const uiMessages = await deriveUIMessagesFromDeltas(threadId, [publicStreamMessage(streamingMessage)], deltas);
// We don't save messages that have already been saved
const numToSkip = stepOrder - streamingMessage.stepOrder;
const messages = await Promise.all(fromUIMessages(uiMessages, streamingMessage)
.slice(numToSkip)
.filter((m) => m.message !== undefined)
.map(async (msg) => {
return {
...pick(msg, [
"message",
"fileIds",
"status",
"finishReason",
"model",
"provider",
"providerMetadata",
"sources",
"reasoning",
"reasoningDetails",
"usage",
"warnings",
"error",
]),
...metadata,
};
}));
return messages;
}))).flat();
return messages;
}
//# sourceMappingURL=streams.js.map