UNPKG

@convex-dev/agent

Version:

An agent component for Convex.

572 lines 23.5 kB
import { assert, omit, pick } from "convex-helpers"; import { mergedStream, stream } from "convex-helpers/server/stream"; import { paginationOptsValidator, } from "convex/server"; import { DEFAULT_MESSAGE_RANGE, DEFAULT_RECENT_MESSAGES, extractText, isTool, } from "../shared.js"; import { vMessageEmbeddingsWithDimension, vMessageStatus, vMessageWithMetadataInternal, vPaginationResult, } from "../validators.js"; import { api, internal } from "./_generated/api.js"; import { action, internalQuery, mutation, query, } from "./_generated/server.js"; import { schema, v, vMessageDoc } from "./schema.js"; import { insertVector, searchVectors } from "./vector/index.js"; import { VectorDimensions, vVectorId, } from "./vector/tables.js"; import { changeRefcount } from "./files.js"; import { getStreamingMessagesWithMetadata } from "./streams.js"; import { partial } from "convex-helpers/validators"; function publicMessage(message) { return omit(message, ["parentMessageId", "stepId", "files"]); } export async function deleteMessage(ctx, messageDoc) { await ctx.db.delete(messageDoc._id); if (messageDoc.embeddingId) { await ctx.db.delete(messageDoc.embeddingId); } if (messageDoc.fileIds) { await changeRefcount(ctx, messageDoc.fileIds, []); } } export const deleteByIds = mutation({ args: { messageIds: v.array(v.id("messages")) }, returns: v.array(v.id("messages")), handler: async (ctx, args) => { const deletedMessageIds = await Promise.all(args.messageIds.map(async (id) => { const message = await ctx.db.get(id); if (message) { await deleteMessage(ctx, message); return id; } return null; })); return deletedMessageIds.filter((id) => id !== null); }, }); export const messageStatuses = vMessageDoc.fields.status.members.map((m) => m.value); export const deleteByOrder = mutation({ args: { threadId: v.id("threads"), startOrder: v.number(), startStepOrder: v.optional(v.number()), endOrder: v.number(), endStepOrder: v.optional(v.number()), }, returns: v.object({ isDone: v.boolean(), 
lastOrder: v.optional(v.number()), lastStepOrder: v.optional(v.number()), }), handler: async (ctx, args) => { const messages = await orderedMessagesStream(ctx, args.threadId, "asc", args.startOrder) .narrow({ lowerBound: args.startStepOrder ? [args.startOrder, args.startStepOrder] : [args.startOrder], lowerBoundInclusive: true, upperBound: args.endStepOrder ? [args.endOrder, args.endStepOrder] : [args.endOrder], upperBoundInclusive: false, }) .take(64); await Promise.all(messages.map((m) => deleteMessage(ctx, m))); return { isDone: messages.length < 64, lastOrder: messages.at(-1)?.order, lastStepOrder: messages.at(-1)?.stepOrder, }; }, }); const addMessagesArgs = { userId: v.optional(v.string()), threadId: v.id("threads"), promptMessageId: v.optional(v.id("messages")), agentName: v.optional(v.string()), messages: v.array(vMessageWithMetadataInternal), embeddings: v.optional(vMessageEmbeddingsWithDimension), failPendingSteps: v.optional(v.boolean()), // A pending message to update. If the pending message failed, abort. 
pendingMessageId: v.optional(v.id("messages")), }; export const addMessages = mutation({ args: addMessagesArgs, handler: addMessagesHandler, returns: v.object({ messages: v.array(vMessageDoc) }), }); async function addMessagesHandler(ctx, args) { let userId = args.userId; const threadId = args.threadId; if (!userId && args.threadId) { const thread = await ctx.db.get(args.threadId); assert(thread, `Thread ${args.threadId} not found`); userId = thread.userId; } const { embeddings, failPendingSteps, messages, promptMessageId, pendingMessageId, ...rest } = args; const promptMessage = promptMessageId && (await ctx.db.get(promptMessageId)); if (failPendingSteps) { assert(args.threadId, "threadId is required to fail pending steps"); const pendingMessages = await ctx.db .query("messages") .withIndex("threadId_status_tool_order_stepOrder", (q) => q.eq("threadId", threadId).eq("status", "pending")) .order("desc") .take(100); await Promise.all(pendingMessages .filter((m) => !promptMessage || m.order === promptMessage.order) .filter((m) => !pendingMessageId || m._id !== pendingMessageId) .map(async (m) => { if (m.embeddingId) { await ctx.db.delete(m.embeddingId); } await ctx.db.patch(m._id, { status: "failed", error: "Restarting", embeddingId: undefined, }); })); } let order, stepOrder; let fail = false; let error; if (promptMessageId) { assert(promptMessage, `Parent message ${promptMessageId} not found`); if (promptMessage.status === "failed") { fail = true; error = promptMessage.error ?? error ?? "The prompt message failed"; } order = promptMessage.order; // Defend against there being existing messages with this parent. const maxMessage = await getMaxMessage(ctx, threadId, order); stepOrder = maxMessage?.stepOrder ?? promptMessage.stepOrder; } else { const maxMessage = await getMaxMessage(ctx, threadId); order = maxMessage?.order ?? -1; stepOrder = maxMessage?.stepOrder ?? 
-1; } const toReturn = []; if (embeddings) { assert(embeddings.vectors.length === messages.length, "embeddings.vectors.length must match messages.length"); } for (let i = 0; i < messages.length; i++) { const message = messages[i]; let embeddingId; if (embeddings && embeddings.vectors[i] && !fail && message.status !== "failed") { embeddingId = await insertVector(ctx, embeddings.dimension, { vector: embeddings.vectors[i], model: embeddings.model, table: "messages", userId, threadId, }); } const messageDoc = { ...rest, ...message, embeddingId, parentMessageId: promptMessageId, userId, tool: isTool(message.message), text: extractText(message.message), status: fail ? "failed" : (message.status ?? "success"), error: fail ? error : message.error, }; // If there is a pending message, we replace that one with the first message // and subsequent ones will follow the regular order/subOrder advancement. if (i === 0 && pendingMessageId) { const pendingMessage = await ctx.db.get(pendingMessageId); assert(pendingMessage, `Pending msg ${pendingMessageId} not found`); if (pendingMessage.status === "failed") { fail = true; error = `Trying to update a message that failed: ${pendingMessageId}, ` + `error: ${pendingMessage.error ?? error}`; messageDoc.status = "failed"; messageDoc.error = error; } if (message.fileIds) { await changeRefcount(ctx, pendingMessage.fileIds ?? [], message.fileIds); } await ctx.db.replace(pendingMessage._id, { ...messageDoc, order: pendingMessage.order, stepOrder: pendingMessage.stepOrder, }); toReturn.push(pendingMessage); continue; } if (message.message.role === "user") { if (promptMessage && promptMessage.order === order) { // see if there's a later message than the parent message order const maxMessage = await getMaxMessage(ctx, threadId); order = (maxMessage?.order ?? 
order) + 1; } else { order++; } stepOrder = 0; } else { if (order < 0) { order = 0; } stepOrder++; } const messageId = await ctx.db.insert("messages", { ...messageDoc, order, stepOrder, }); if (message.fileIds) { await changeRefcount(ctx, [], message.fileIds); } // TODO: delete the associated stream data for the order/stepOrder toReturn.push((await ctx.db.get(messageId))); } return { messages: toReturn.map(publicMessage) }; } // exported for tests export async function getMaxMessage(ctx, threadId, order) { return orderedMessagesStream(ctx, threadId, "desc", order).first(); } function orderedMessagesStream(ctx, threadId, sortOrder, order) { return mergedStream([true, false].flatMap((tool) => messageStatuses.map((status) => stream(ctx.db, schema) .query("messages") .withIndex("threadId_status_tool_order_stepOrder", (q) => { const qq = q .eq("threadId", threadId) .eq("status", status) .eq("tool", tool); if (order !== undefined) { return qq.eq("order", order); } return qq; }) .order(sortOrder))), ["order", "stepOrder"]); } export const finalizeMessage = mutation({ args: { messageId: v.id("messages"), result: v.union(v.object({ status: v.literal("success") }), v.object({ status: v.literal("failed"), error: v.string() })), }, returns: v.null(), handler: async (ctx, { messageId, result }) => { const message = await ctx.db.get(messageId); assert(message, `Message ${messageId} not found`); if (message.status !== "pending") { console.log("Trying to finalize a message that's already", message.status); return; } // See if we can add any in-progress data if (message.message === undefined) { const messages = await getStreamingMessagesWithMetadata(ctx, message, result); if (messages.length > 0) { await addMessagesHandler(ctx, { messages, threadId: message.threadId, agentName: message.agentName, failPendingSteps: false, pendingMessageId: messageId, userId: message.userId, embeddings: undefined, }); return; } } if (result.status === "failed") { if (message.embeddingId) { await 
ctx.db.delete(message.embeddingId); } await ctx.db.patch(messageId, { status: "failed", error: result.error, embeddingId: undefined, }); } else { await ctx.db.patch(messageId, { status: "success" }); } }, }); export const updateMessage = mutation({ args: { messageId: v.id("messages"), patch: v.object(partial(pick(schema.tables.messages.validator.fields, [ "message", "fileIds", "status", "error", "model", "provider", "providerOptions", "finishReason", ]))), }, returns: vMessageDoc, handler: async (ctx, args) => { const message = await ctx.db.get(args.messageId); assert(message, `Message ${args.messageId} not found`); if (args.patch.fileIds) { await changeRefcount(ctx, message.fileIds ?? [], args.patch.fileIds); } const patch = { ...args.patch }; if (args.patch.message !== undefined) { patch.message = args.patch.message; patch.tool = isTool(args.patch.message); patch.text = extractText(args.patch.message); } if (args.patch.status === "failed") { if (message.embeddingId) { await ctx.db.delete(message.embeddingId); } patch.embeddingId = undefined; } await ctx.db.patch(args.messageId, patch); return publicMessage((await ctx.db.get(args.messageId))); }, }); export const listMessagesByThreadId = query({ args: { threadId: v.id("threads"), excludeToolMessages: v.optional(v.boolean()), /** What order to sort the messages in. To get the latest, use "desc". */ order: v.union(v.literal("asc"), v.literal("desc")), paginationOpts: v.optional(paginationOptsValidator), statuses: v.optional(v.array(vMessageStatus)), upToAndIncludingMessageId: v.optional(v.id("messages")), }, handler: async (ctx, args) => { const statuses = args.statuses ?? vMessageStatus.members.map((m) => m.value); const last = args.upToAndIncludingMessageId && (await ctx.db.get(args.upToAndIncludingMessageId)); assert(!last || last.threadId === args.threadId, "upToAndIncludingMessageId must be a message in the thread"); const toolOptions = args.excludeToolMessages ? 
[false] : [true, false]; const order = args.order ?? "desc"; const streams = toolOptions.flatMap((tool) => statuses.map((status) => stream(ctx.db, schema) .query("messages") .withIndex("threadId_status_tool_order_stepOrder", (q) => { const qq = q .eq("threadId", args.threadId) .eq("status", status) .eq("tool", tool); if (last) { return qq.lte("order", last.order); } return qq; }) .order(order) .filterWith( // We allow all messages on the same order. async (m) => !last || m.order < last.order || m.order === last.order))); const messages = await mergedStream(streams, [ "order", "stepOrder", ]).paginate(args.paginationOpts ?? { numItems: DEFAULT_RECENT_MESSAGES, cursor: null, }); return { ...messages, page: messages.page.map(publicMessage) }; }, returns: vPaginationResult(vMessageDoc), }); export const getMessagesByIds = query({ args: { messageIds: v.array(v.id("messages")) }, handler: async (ctx, args) => { return (await Promise.all(args.messageIds.map((id) => ctx.db.get(id)))).map((m) => (m ? 
publicMessage(m) : null)); }, returns: v.array(v.union(v.null(), vMessageDoc)), }); export const searchMessages = action({ args: { threadId: v.optional(v.id("threads")), searchAllMessagesForUserId: v.optional(v.string()), beforeMessageId: v.optional(v.id("messages")), embedding: v.optional(v.array(v.number())), embeddingModel: v.optional(v.string()), text: v.optional(v.string()), limit: v.number(), vectorScoreThreshold: v.optional(v.number()), messageRange: v.optional(v.object({ before: v.number(), after: v.number() })), }, returns: v.array(vMessageDoc), handler: async (ctx, args) => { assert(args.searchAllMessagesForUserId || args.threadId, "Specify userId or threadId"); const limit = args.limit; let textSearchMessages; if (args.text) { textSearchMessages = await ctx.runQuery(api.messages.textSearch, { searchAllMessagesForUserId: args.searchAllMessagesForUserId, threadId: args.threadId, text: args.text, limit, beforeMessageId: args.beforeMessageId, }); } if (args.embedding) { const dimension = args.embedding.length; if (!VectorDimensions.includes(dimension)) { throw new Error(`Unsupported embedding dimension: ${dimension}`); } const vectors = (await searchVectors(ctx, args.embedding, { dimension, model: args.embeddingModel ?? "unknown", table: "messages", searchAllMessagesForUserId: args.searchAllMessagesForUserId, threadId: args.threadId, limit, })).filter((v) => v._score > (args.vectorScoreThreshold ?? 0)); // Reciprocal rank fusion const k = 10; const textEmbeddingIds = textSearchMessages?.map((m) => m.embeddingId); const vectorScores = vectors .map((v, i) => ({ id: v._id, score: 1 / (i + k) + 1 / ((textEmbeddingIds?.indexOf(v._id) ?? 
Infinity) + k), })) .sort((a, b) => b.score - a.score); const embeddingIds = vectorScores.slice(0, limit).map((v) => v.id); const messages = await ctx.runQuery(internal.messages._fetchSearchMessages, { searchAllMessagesForUserId: args.searchAllMessagesForUserId, threadId: args.threadId, embeddingIds, textSearchMessages: textSearchMessages?.filter((m) => !embeddingIds.includes(m.embeddingId)), messageRange: args.messageRange ?? DEFAULT_MESSAGE_RANGE, beforeMessageId: args.beforeMessageId, limit, }); return messages; } return textSearchMessages?.flat() ?? []; }, }); export const _fetchSearchMessages = internalQuery({ args: { threadId: v.optional(v.id("threads")), embeddingIds: v.array(vVectorId), searchAllMessagesForUserId: v.optional(v.string()), textSearchMessages: v.optional(v.array(vMessageDoc)), messageRange: v.object({ before: v.number(), after: v.number() }), beforeMessageId: v.optional(v.id("messages")), limit: v.number(), }, returns: v.array(vMessageDoc), handler: async (ctx, args) => { const beforeMessage = args.beforeMessageId && (await ctx.db.get(args.beforeMessageId)); const { searchAllMessagesForUserId, threadId } = args; assert(searchAllMessagesForUserId || threadId, "Specify searchAllMessagesForUserId or threadId to search"); let messages = (await Promise.all(args.embeddingIds.map((embeddingId) => ctx.db .query("messages") .withIndex("embeddingId_threadId", (q) => searchAllMessagesForUserId ? q.eq("embeddingId", embeddingId) : q.eq("embeddingId", embeddingId).eq("threadId", threadId)) .filter((q) => q.and(q.eq(q.field("status"), "success"), searchAllMessagesForUserId ? q.eq(q.field("userId"), searchAllMessagesForUserId) : q.eq(q.field("threadId"), threadId))) .first()))) .filter((m) => m !== undefined && m !== null && !m.tool && (!beforeMessage || m.order < beforeMessage.order || (m.order === beforeMessage.order && m.stepOrder < beforeMessage.stepOrder))) .map(publicMessage); messages.push(...(args.textSearchMessages ?? 
[])); // TODO: prioritize more recent messages messages.sort((a, b) => a.order - b.order); messages = messages.slice(0, args.limit); // Fetch the surrounding messages if (!threadId) { return messages.sort((a, b) => a.order - b.order); } const included = {}; for (const m of messages) { const searchId = m.threadId ?? m.userId; if (!included[searchId]) { included[searchId] = new Set(); } included[searchId].add(m.order); } const ranges = {}; const { before, after } = args.messageRange; for (const m of messages) { const searchId = m.threadId ?? m.userId; const order = m.order; let earliest = order - before; let latest = order + after; for (; earliest <= latest; earliest++) { if (!included[searchId].has(earliest)) { break; } } for (; latest >= earliest; latest--) { if (!included[searchId].has(latest)) { break; } } for (let i = earliest; i <= latest; i++) { included[searchId].add(i); } if (earliest !== latest) { const surrounding = await ctx.db .query("messages") .withIndex("threadId_status_tool_order_stepOrder", (q) => q .eq("threadId", m.threadId) .eq("status", "success") .eq("tool", false) .gte("order", earliest) .lte("order", latest)) .collect(); if (!ranges[searchId]) { ranges[searchId] = []; } ranges[searchId].push(...surrounding); } } for (const r of Object.values(ranges).flat()) { if (!messages.some((m) => m._id === r._id)) { messages.push(publicMessage(r)); } } return messages.sort((a, b) => a.order - b.order); }, }); // returns ranges of messages in order of text search relevance, // excluding duplicates in later ranges. 
export const textSearch = query({ args: { threadId: v.optional(v.id("threads")), searchAllMessagesForUserId: v.optional(v.string()), text: v.string(), limit: v.number(), beforeMessageId: v.optional(v.id("messages")), }, handler: async (ctx, args) => { assert(args.searchAllMessagesForUserId || args.threadId, "Specify userId or threadId"); const beforeMessage = args.beforeMessageId && (await ctx.db.get(args.beforeMessageId)); const order = beforeMessage?.order; const messages = await ctx.db .query("messages") .withSearchIndex("text_search", (q) => args.searchAllMessagesForUserId ? q .search("text", args.text) .eq("userId", args.searchAllMessagesForUserId) : q.search("text", args.text).eq("threadId", args.threadId)) // Just in case tool messages slip through .filter((q) => { const qq = q.eq(q.field("tool"), false); if (order) { return q.and(qq, q.lte(q.field("order"), order)); } return qq; }) .take(args.limit); return messages .filter((m) => !beforeMessage || m.order < beforeMessage.order || (m.order === beforeMessage.order && m.stepOrder < beforeMessage.stepOrder)) .map(publicMessage); }, returns: v.array(vMessageDoc), }); //# sourceMappingURL=messages.js.map