UNPKG

@convex-dev/agent

Version:

An agent component for Convex.

747 lines 30.1 kB
import { assert, omit, pick } from "convex-helpers";
import { mergedStream, stream } from "convex-helpers/server/stream";
import { paginationOptsValidator } from "convex/server";
import {
  DEFAULT_MESSAGE_RANGE,
  DEFAULT_RECENT_MESSAGES,
  extractText,
  isTool,
  sorted,
} from "../shared.js";
import {
  vMessageDoc,
  vMessageEmbeddingsWithDimension,
  vMessageStatus,
  vMessageWithMetadataInternal,
  vPaginationResult,
} from "../validators.js";
import { api, internal } from "./_generated/api.js";
import {
  action,
  internalMutation,
  internalQuery,
  mutation,
  query,
} from "./_generated/server.js";
import { schema, v } from "./schema.js";
import { insertVector, searchVectors } from "./vector/index.js";
import { validateVectorDimension, vVectorId } from "./vector/tables.js";
import { changeRefcount } from "./files.js";
import { getStreamingMessagesWithMetadata } from "./streams.js";
import { partial } from "convex-helpers/validators";

/** Maximum number of messages removed by a single `deleteByOrder` call. */
const DELETE_BATCH_SIZE = 64;

/**
 * Strip internal-only fields before handing a message document to callers.
 * @param {object} message - A document from the `messages` table.
 * @returns {object} The document without `parentMessageId`, `stepId`, `files`.
 */
function publicMessage(message) {
  return omit(message, ["parentMessageId", "stepId", "files"]);
}

/**
 * Delete a message document, its embedding row (if any), and release the
 * file references it held.
 * @param {object} ctx - Mutation context.
 * @param {object} messageDoc - The `messages` document to delete.
 */
export async function deleteMessage(ctx, messageDoc) {
  await ctx.db.delete(messageDoc._id);
  if (messageDoc.embeddingId) {
    await ctx.db.delete(messageDoc.embeddingId);
  }
  if (messageDoc.fileIds) {
    await changeRefcount(ctx, messageDoc.fileIds, []);
  }
}

/**
 * Delete the given messages.
 * @returns The subset of `messageIds` that existed and were deleted.
 */
export const deleteByIds = mutation({
  args: { messageIds: v.array(v.id("messages")) },
  returns: v.array(v.id("messages")),
  handler: async (ctx, args) => {
    const deletedMessageIds = await Promise.all(
      args.messageIds.map(async (id) => {
        const message = await ctx.db.get(id);
        if (message) {
          await deleteMessage(ctx, message);
          return id;
        }
        return null;
      }),
    );
    return deletedMessageIds.filter((id) => id !== null);
  },
});

/** Every literal value accepted by the message `status` field. */
export const messageStatuses = vMessageDoc.fields.status.members.map(
  (m) => m.value,
);

/**
 * Delete up to DELETE_BATCH_SIZE (64) messages of a thread whose
 * (order, stepOrder) lies in [start, end): the start is inclusive, the end
 * exclusive. When a step order is omitted, only the order bounds the range.
 * `isDone` is false when a full batch was deleted (more may remain); the
 * position of the last deleted message is returned so a caller can continue.
 */
export const deleteByOrder = mutation({
  args: {
    threadId: v.id("threads"),
    startOrder: v.number(),
    startStepOrder: v.optional(v.number()),
    endOrder: v.number(),
    endStepOrder: v.optional(v.number()),
  },
  returns: v.object({
    isDone: v.boolean(),
    lastOrder: v.optional(v.number()),
    lastStepOrder: v.optional(v.number()),
  }),
  handler: async (ctx, args) => {
    const messages = await orderedMessagesStream(ctx, {
      threadId: args.threadId,
      sortOrder: "asc",
      startOrder: args.startOrder,
      startOrderBound: "gte",
    })
      .narrow({
        // Test for presence rather than truthiness so a step order of 0 is
        // treated as an explicit bound.
        lowerBound:
          args.startStepOrder !== undefined
            ? [args.startOrder, args.startStepOrder]
            : [args.startOrder],
        lowerBoundInclusive: true,
        upperBound:
          args.endStepOrder !== undefined
            ? [args.endOrder, args.endStepOrder]
            : [args.endOrder],
        upperBoundInclusive: false,
      })
      .take(DELETE_BATCH_SIZE);
    await Promise.all(messages.map((m) => deleteMessage(ctx, m)));
    return {
      isDone: messages.length < DELETE_BATCH_SIZE,
      lastOrder: messages.at(-1)?.order,
      lastStepOrder: messages.at(-1)?.stepOrder,
    };
  },
});

const addMessagesArgs = {
  userId: v.optional(v.string()),
  threadId: v.id("threads"),
  promptMessageId: v.optional(v.id("messages")),
  agentName: v.optional(v.string()),
  messages: v.array(vMessageWithMetadataInternal),
  embeddings: v.optional(vMessageEmbeddingsWithDimension),
  failPendingSteps: v.optional(v.boolean()),
  // A pending message to replace with the first message. If the pending
  // message has already failed, the new messages are stored as failed.
  pendingMessageId: v.optional(v.id("messages")),
  // If true, embeddings are stored without a userId and `text` is left unset,
  // so these messages are not found by user-wide vector search nor by text
  // search.
  hideFromUserIdSearch: v.optional(v.boolean()),
};

/** Append messages to a thread (see `addMessagesHandler`). */
export const addMessages = mutation({
  args: addMessagesArgs,
  handler: addMessagesHandler,
  returns: v.object({ messages: v.array(vMessageDoc) }),
});

/**
 * Insert `args.messages` into a thread, assigning order/stepOrder.
 *
 * - With `promptMessageId`, messages continue the prompt's order; if the
 *   prompt failed, the new messages are stored as failed with its error.
 * - With `pendingMessageId`, the first message replaces the pending message
 *   in place (keeping its order/stepOrder).
 * - A "user" message starts a new order (stepOrder 0); any other role
 *   advances stepOrder within the current order.
 * - With `failPendingSteps`, up to 100 pending messages of the thread (only
 *   those at the prompt's order when a prompt is given, never the
 *   `pendingMessageId` one) are marked failed with error "Restarting".
 *
 * @returns {{ messages: object[] }} The written documents, public fields only.
 */
async function addMessagesHandler(ctx, args) {
  let userId = args.userId;
  const threadId = args.threadId;
  if (!userId && args.threadId) {
    const thread = await ctx.db.get(args.threadId);
    assert(thread, `Thread ${args.threadId} not found`);
    userId = thread.userId;
  }
  const {
    embeddings,
    failPendingSteps,
    messages,
    promptMessageId,
    pendingMessageId,
    hideFromUserIdSearch,
    ...rest
  } = args;
  const promptMessage = promptMessageId && (await ctx.db.get(promptMessageId));
  if (failPendingSteps) {
    assert(args.threadId, "threadId is required to fail pending steps");
    const pendingMessages = await ctx.db
      .query("messages")
      .withIndex("threadId_status_tool_order_stepOrder", (q) =>
        q.eq("threadId", threadId).eq("status", "pending"),
      )
      .order("desc")
      .take(100);
    await Promise.all(
      pendingMessages
        .filter((m) => !promptMessage || m.order === promptMessage.order)
        .filter((m) => !pendingMessageId || m._id !== pendingMessageId)
        .map(async (m) => {
          if (m.embeddingId) {
            await ctx.db.delete(m.embeddingId);
          }
          await ctx.db.patch(m._id, {
            status: "failed",
            error: "Restarting",
            embeddingId: undefined,
          });
        }),
    );
  }
  let order;
  let stepOrder;
  let fail = false;
  let error;
  if (promptMessageId) {
    assert(promptMessage, `Parent message ${promptMessageId} not found`);
    if (promptMessage.status === "failed") {
      fail = true;
      error = promptMessage.error ?? error ?? "The prompt message failed";
    }
    order = promptMessage.order;
    // Defend against there being existing messages with this parent.
    const maxMessage = await getMaxMessage(ctx, threadId, order);
    stepOrder = maxMessage?.stepOrder ?? promptMessage.stepOrder;
  } else {
    const maxMessage = await getMaxMessage(ctx, threadId);
    order = maxMessage?.order ?? -1;
    stepOrder = maxMessage?.stepOrder ?? -1;
  }
  const toReturn = [];
  if (embeddings) {
    assert(
      embeddings.vectors.length === messages.length,
      "embeddings.vectors.length must match messages.length",
    );
  }
  for (let i = 0; i < messages.length; i++) {
    const message = messages[i];
    // The first message replaces the pending message, if any. Look it up
    // before writing anything so that a pending message that already failed
    // does not get a fresh embedding attached (failed messages carry none).
    let pendingMessage;
    if (i === 0 && pendingMessageId) {
      pendingMessage = await ctx.db.get(pendingMessageId);
      assert(pendingMessage, `Pending msg ${pendingMessageId} not found`);
      if (pendingMessage.status === "failed") {
        fail = true;
        error =
          `Trying to update a message that failed: ${pendingMessageId}, ` +
          `error: ${pendingMessage.error ?? error}`;
      }
    }
    let embeddingId;
    if (
      embeddings &&
      embeddings.vectors[i] &&
      !fail &&
      message.status !== "failed"
    ) {
      embeddingId = await insertVector(ctx, embeddings.dimension, {
        vector: embeddings.vectors[i],
        model: embeddings.model,
        table: "messages",
        userId: hideFromUserIdSearch ? undefined : userId,
        threadId,
      });
    }
    const messageDoc = {
      ...rest,
      ...message,
      embeddingId,
      parentMessageId: promptMessageId,
      userId,
      tool: isTool(message.message),
      text: hideFromUserIdSearch ? undefined : extractText(message.message),
      status: fail ? "failed" : (message.status ?? "success"),
      error: fail ? error : message.error,
    };
    if (pendingMessage) {
      // `replace` drops the pending message's fileIds, so release them even
      // when the replacement carries no files of its own.
      if (message.fileIds || pendingMessage.fileIds) {
        await changeRefcount(
          ctx,
          pendingMessage.fileIds ?? [],
          message.fileIds ?? [],
        );
      }
      await ctx.db.replace(pendingMessage._id, {
        ...messageDoc,
        order: pendingMessage.order,
        stepOrder: pendingMessage.stepOrder,
      });
      // Return the document as written, not the pre-replace snapshot.
      toReturn.push(await ctx.db.get(pendingMessage._id));
      continue;
    }
    if (message.message.role === "user") {
      if (promptMessage && promptMessage.order === order) {
        // see if there's a later message than the parent message order
        const maxMessage = await getMaxMessage(ctx, threadId);
        order = (maxMessage?.order ?? order) + 1;
      } else {
        order++;
      }
      stepOrder = 0;
    } else {
      if (order < 0) {
        order = 0;
      }
      stepOrder++;
    }
    const messageId = await ctx.db.insert("messages", {
      ...messageDoc,
      order,
      stepOrder,
    });
    if (message.fileIds) {
      await changeRefcount(ctx, [], message.fileIds);
    }
    // TODO: delete the associated stream data for the order/stepOrder
    toReturn.push(await ctx.db.get(messageId));
  }
  return { messages: toReturn.map(publicMessage) };
}

/**
 * Latest message (highest order, then stepOrder) of a thread across every
 * status and tool flag; when `order` is given, only that order is considered.
 * Exported for tests.
 */
export async function getMaxMessage(ctx, threadId, order) {
  return orderedMessagesStream(ctx, {
    threadId,
    sortOrder: "desc",
    startOrder: order,
    startOrderBound: "eq",
  }).first();
}

/**
 * Merge one index stream per (tool, status) pair into a single stream of a
 * thread's messages sorted by ["order", "stepOrder"] in `args.sortOrder`.
 * `startOrder` with `startOrderBound` "gte" or "eq" restricts the order.
 */
function orderedMessagesStream(ctx, args) {
  return mergedStream(
    [true, false].flatMap((tool) =>
      messageStatuses.map((status) =>
        stream(ctx.db, schema)
          .query("messages")
          .withIndex("threadId_status_tool_order_stepOrder", (q) => {
            const qq = q
              .eq("threadId", args.threadId)
              .eq("status", status)
              .eq("tool", tool);
            if (args.startOrder !== undefined) {
              if (args.startOrderBound === "gte") {
                return qq.gte("order", args.startOrder);
              } else {
                return qq.eq("order", args.startOrder);
              }
            }
            return qq;
          })
          .order(args.sortOrder),
      ),
    ),
    ["order", "stepOrder"],
  );
}

/**
 * Finish a pending message. Non-pending messages are left untouched.
 * If the message has no content yet, data from
 * `getStreamingMessagesWithMetadata` (presumably the streamed deltas —
 * verify in streams.js) replaces it via `addMessagesHandler`. Otherwise the
 * status is set; on failure the embedding is deleted as well.
 */
export const finalizeMessage = mutation({
  args: {
    messageId: v.id("messages"),
    result: v.union(
      v.object({ status: v.literal("success") }),
      v.object({ status: v.literal("failed"), error: v.string() }),
    ),
  },
  returns: v.null(),
  handler: async (ctx, { messageId, result }) => {
    const message = await ctx.db.get(messageId);
    assert(message, `Message ${messageId} not found`);
    if (message.status !== "pending") {
      console.debug(
        "Trying to finalize a message that's already",
        message.status,
      );
      return;
    }
    // See if we can add any in-progress data
    if (!message.message?.content.length) {
      const messages = await getStreamingMessagesWithMetadata(
        ctx,
        message,
        result,
      );
      if (messages.length > 0) {
        await addMessagesHandler(ctx, {
          messages,
          threadId: message.threadId,
          agentName: message.agentName,
          failPendingSteps: false,
          pendingMessageId: messageId,
          userId: message.userId,
          embeddings: undefined,
        });
        return;
      }
    }
    if (result.status === "failed") {
      if (message.embeddingId) {
        await ctx.db.delete(message.embeddingId);
      }
      await ctx.db.patch(messageId, {
        status: "failed",
        error: result.error,
        embeddingId: undefined,
      });
    } else {
      await ctx.db.patch(messageId, { status: "success" });
    }
  },
});

/**
 * Patch a subset of a message's fields. Replacing `message` recomputes the
 * derived `tool` and `text` fields; changing `fileIds` adjusts refcounts;
 * marking it failed deletes its embedding.
 */
export const updateMessage = mutation({
  args: {
    messageId: v.id("messages"),
    patch: v.object(
      partial(
        pick(schema.tables.messages.validator.fields, [
          "message",
          "fileIds",
          "status",
          "error",
          "model",
          "provider",
          "providerOptions",
          "finishReason",
        ]),
      ),
    ),
  },
  returns: vMessageDoc,
  handler: async (ctx, args) => {
    const message = await ctx.db.get(args.messageId);
    assert(message, `Message ${args.messageId} not found`);
    if (args.patch.fileIds) {
      await changeRefcount(ctx, message.fileIds ?? [], args.patch.fileIds);
    }
    const patch = { ...args.patch };
    if (args.patch.message !== undefined) {
      patch.tool = isTool(args.patch.message);
      patch.text = extractText(args.patch.message);
    }
    if (args.patch.status === "failed") {
      if (message.embeddingId) {
        await ctx.db.delete(message.embeddingId);
      }
      patch.embeddingId = undefined;
    }
    await ctx.db.patch(args.messageId, patch);
    return publicMessage(await ctx.db.get(args.messageId));
  },
});

const cloneMessageArgs = {
  sourceThreadId: v.id("threads"),
  targetThreadId: v.id("threads"),
  // defaults to false, so searching for a message by userId will not find
  // these copies
  copyUserIdForVectorSearch: v.optional(v.boolean()),
  // defaults to false, so tool calls & responses will be copied
  excludeToolMessages: v.optional(v.boolean()),
  // defaults to copying all messages, but you could just copy success messages.
  statuses: v.optional(v.array(vMessageStatus)),
  // stop at this message id
  upToAndIncludingMessageId: v.optional(v.id("messages")),
  // defaults to 0. the messages will be inserted starting at this order.
  insertAtOrder: v.optional(v.number()),
};

/**
 * Copy one page of the source thread's messages (newest first) into the
 * target thread, shifting each order by `insertAtOrder`. Any (order,
 * stepOrder) position already present in the target is skipped. Embeddings
 * are duplicated into new vector rows and file refcounts incremented.
 * `numCopied` counts the whole page, including skipped messages.
 */
export const cloneMessageBatch = internalMutation({
  args: {
    ...cloneMessageArgs,
    paginationOpts: paginationOptsValidator,
  },
  handler: async (ctx, args) => {
    const orderOffset = args.insertAtOrder ?? 0;
    const result = await listMessagesByThreadIdHandler(ctx, {
      threadId: args.sourceThreadId,
      excludeToolMessages: args.excludeToolMessages,
      order: "desc",
      paginationOpts: args.paginationOpts,
      statuses: args.statuses,
      upToAndIncludingMessageId: args.upToAndIncludingMessageId,
    });
    const page = result.page;
    let existing = [];
    if (page.length > 0) {
      // The page is sorted descending, so derive the bounds explicitly, and
      // shift them into the target thread's order space.
      const orders = page.map((m) => m.order);
      const lowest = orderOffset + Math.min(...orders);
      const highest = orderOffset + Math.max(...orders);
      existing = await mergedStream(
        [true, false].flatMap((tool) =>
          messageStatuses.map((status) =>
            stream(ctx.db, schema)
              .query("messages")
              .withIndex("threadId_status_tool_order_stepOrder", (q) =>
                q
                  .eq("threadId", args.targetThreadId)
                  .eq("status", status)
                  .eq("tool", tool)
                  .gte("order", lowest)
                  .lte("order", highest),
              ),
          ),
        ),
        ["order", "stepOrder"],
      ).collect();
    }
    const occupied = new Set(existing.map((e) => `${e.order}:${e.stepOrder}`));
    await Promise.all(
      page
        .filter((m) => !occupied.has(`${orderOffset + m.order}:${m.stepOrder}`))
        .map(async (m) => {
          // update file refs
          if (m.fileIds) {
            await changeRefcount(ctx, [], m.fileIds);
          }
          let embeddingId = undefined;
          if (m.embeddingId) {
            const vector = await ctx.db.get(m.embeddingId);
            assert(vector, `Vector ${m.embeddingId} not found`);
            const dimension = vector.vector.length;
            validateVectorDimension(dimension);
            embeddingId = await insertVector(ctx, dimension, {
              ...pick(vector, ["model", "table", "vector"]),
              userId: args.copyUserIdForVectorSearch ? vector.userId : undefined,
              threadId: args.targetThreadId,
            });
          }
          await ctx.db.insert("messages", {
            ...omit(m, [
              "_id",
              "_creationTime",
              "threadId",
              "order",
              "embeddingId",
            ]),
            embeddingId,
            threadId: args.targetThreadId,
            order: orderOffset + m.order,
          });
        }),
    );
    return {
      numCopied: page.length,
      continueCursor: result.continueCursor,
      isDone: result.isDone,
    };
  },
});

/**
 * Copy messages between threads in batches of `batchSize` (default
 * DEFAULT_RECENT_MESSAGES), stopping after `limit` messages or when the
 * source is exhausted.
 * @returns {number} Number of source messages processed.
 */
export const cloneThread = action({
  args: {
    ...cloneMessageArgs,
    batchSize: v.optional(v.number()),
    // how many messages to copy
    limit: v.optional(v.number()),
  },
  returns: v.number(),
  handler: async (ctx, args) => {
    // `batchSize` and `limit` are not declared by cloneMessageBatch's
    // validator, and Convex rejects undeclared argument fields, so strip them.
    const { batchSize, limit, ...cloneArgs } = args;
    const maxToCopy = limit ?? Infinity;
    let cursor = null;
    let copiedSoFar = 0;
    while (copiedSoFar < maxToCopy) {
      // Never request more than what is left of the limit.
      const numToCopy = Math.min(
        batchSize ?? DEFAULT_RECENT_MESSAGES,
        maxToCopy - copiedSoFar,
      );
      const result = await ctx.runMutation(
        internal.messages.cloneMessageBatch,
        {
          ...cloneArgs,
          paginationOpts: {
            cursor,
            numItems: numToCopy,
          },
        },
      );
      copiedSoFar += result.numCopied;
      cursor = result.continueCursor;
      if (result.isDone) {
        break;
      }
    }
    return copiedSoFar;
  },
});

export const listMessagesByThreadIdArgs = {
  threadId: v.id("threads"),
  excludeToolMessages: v.optional(v.boolean()),
  /** What order to sort the messages in. To get the latest, use "desc". */
  order: v.union(v.literal("asc"), v.literal("desc")),
  paginationOpts: v.optional(paginationOptsValidator),
  statuses: v.optional(v.array(vMessageStatus)),
  upToAndIncludingMessageId: v.optional(v.id("messages")),
};

/** Paginated list of a thread's messages (public fields only). */
export const listMessagesByThreadId = query({
  args: listMessagesByThreadIdArgs,
  handler: async (ctx, args) => {
    const messages = await listMessagesByThreadIdHandler(ctx, args);
    return { ...messages, page: messages.page.map(publicMessage) };
  },
  returns: vPaginationResult(vMessageDoc),
});

/**
 * Paginate a thread's messages sorted by (order, stepOrder), filtered by
 * status and optionally excluding tool messages. With
 * `upToAndIncludingMessageId`, messages whose order exceeds that message's
 * order are excluded (every step of the same order is kept). Defaults to
 * DEFAULT_RECENT_MESSAGES items from the start when no pagination is given.
 */
async function listMessagesByThreadIdHandler(ctx, args) {
  const statuses =
    args.statuses ?? vMessageStatus.members.map((m) => m.value);
  const last =
    args.upToAndIncludingMessageId &&
    (await ctx.db.get(args.upToAndIncludingMessageId));
  assert(
    !last || last.threadId === args.threadId,
    "upToAndIncludingMessageId must be a message in the thread",
  );
  const toolOptions = args.excludeToolMessages ? [false] : [true, false];
  const order = args.order ?? "desc";
  const streams = toolOptions.flatMap((tool) =>
    statuses.map((status) =>
      stream(ctx.db, schema)
        .query("messages")
        .withIndex("threadId_status_tool_order_stepOrder", (q) => {
          const qq = q
            .eq("threadId", args.threadId)
            .eq("status", status)
            .eq("tool", tool);
          if (last) {
            return qq.lte("order", last.order);
          }
          return qq;
        })
        .order(order)
        .filterWith(
          // We allow all messages on the same order.
          async (m) => !last || m.order <= last.order,
        ),
    ),
  );
  const messages = await mergedStream(streams, ["order", "stepOrder"]).paginate(
    args.paginationOpts ?? {
      numItems: DEFAULT_RECENT_MESSAGES,
      cursor: null,
    },
  );
  if (messages.page.length === 0) {
    messages.isDone = true;
  }
  return messages;
}

/** Fetch messages by id, preserving input order; missing ids map to null. */
export const getMessagesByIds = query({
  args: { messageIds: v.array(v.id("messages")) },
  handler: async (ctx, args) => {
    return (
      await Promise.all(args.messageIds.map((id) => ctx.db.get(id)))
    ).map((m) => (m ? publicMessage(m) : null));
  },
  returns: v.array(v.union(v.null(), vMessageDoc)),
});

/**
 * Search a thread's (or a user's) messages by text and/or vector similarity.
 * With vector search, vector hits are ranked by reciprocal rank fusion with
 * the text ranking, then expanded by `_fetchSearchMessages`. Text-only
 * search returns the text hits directly.
 */
export const searchMessages = action({
  args: {
    threadId: v.optional(v.id("threads")),
    searchAllMessagesForUserId: v.optional(v.string()),
    targetMessageId: v.optional(v.id("messages")),
    embedding: v.optional(v.array(v.number())),
    embeddingModel: v.optional(v.string()),
    text: v.optional(v.string()),
    textSearch: v.optional(v.boolean()),
    vectorSearch: v.optional(v.boolean()),
    limit: v.number(),
    vectorScoreThreshold: v.optional(v.number()),
    messageRange: v.optional(
      v.object({ before: v.number(), after: v.number() }),
    ),
  },
  returns: v.array(vMessageDoc),
  handler: async (ctx, args) => {
    assert(
      args.searchAllMessagesForUserId || args.threadId,
      "Specify userId or threadId",
    );
    const limit = args.limit;
    let textSearchMessages;
    if (args.textSearch) {
      textSearchMessages = await ctx.runQuery(api.messages.textSearch, {
        searchAllMessagesForUserId: args.searchAllMessagesForUserId,
        threadId: args.threadId,
        targetMessageId: args.targetMessageId,
        text: args.text,
        limit,
      });
    }
    if (args.vectorSearch) {
      let embedding = args.embedding;
      let model = args.embeddingModel;
      if (!embedding) {
        if (args.targetMessageId) {
          const target = await ctx.runQuery(
            api.messages.getMessageSearchFields,
            {
              messageId: args.targetMessageId,
            },
          );
          assert(target, "Target message embedding not found.");
          embedding = target.embedding;
          model = target.embeddingModel;
        }
      }
      assert(embedding && model, "Embedding missing");
      const dimension = embedding.length;
      validateVectorDimension(dimension);
      const vectors = (
        await searchVectors(ctx, embedding, {
          dimension,
          model,
          table: "messages",
          searchAllMessagesForUserId: args.searchAllMessagesForUserId,
          threadId: args.threadId,
          limit,
        })
      ).filter((vec) => vec._score > (args.vectorScoreThreshold ?? 0));
      // Reciprocal rank fusion: each ranking contributes 1 / (rank + k).
      const k = 10;
      const textEmbeddingIds =
        textSearchMessages?.map((m) => m.embeddingId) ?? [];
      const textRank = (id) => {
        const rank = textEmbeddingIds.indexOf(id);
        // indexOf reports "absent" as -1, which would give a bigger bonus
        // (1/9) than the top text hit (1/10); absent must contribute 0.
        return rank === -1 ? Infinity : rank;
      };
      const vectorScores = vectors
        .map((vec, i) => ({
          id: vec._id,
          score: 1 / (i + k) + 1 / (textRank(vec._id) + k),
        }))
        .sort((a, b) => b.score - a.score);
      const embeddingIds = vectorScores.slice(0, limit).map((s) => s.id);
      const messages = await ctx.runQuery(
        internal.messages._fetchSearchMessages,
        {
          searchAllMessagesForUserId: args.searchAllMessagesForUserId,
          threadId: args.threadId,
          embeddingIds,
          textSearchMessages: textSearchMessages?.filter(
            (m) => !embeddingIds.includes(m.embeddingId),
          ),
          messageRange: args.messageRange ?? DEFAULT_MESSAGE_RANGE,
          beforeMessageId: args.targetMessageId,
          limit,
        },
      );
      return messages;
    }
    return textSearchMessages?.flat() ?? [];
  },
});

/**
 * Resolve vector hits to successful, non-tool messages that come before
 * `beforeMessageId` (if given), add the text hits, sort with the shared
 * `sorted` helper (presumably by order — verify in shared.js) and keep
 * `limit` of them. When `threadId` is given, every kept message is then
 * expanded with the successful non-tool messages whose order falls within
 * [order - before, order + after] of `messageRange`.
 */
export const _fetchSearchMessages = internalQuery({
  args: {
    threadId: v.optional(v.id("threads")),
    embeddingIds: v.array(vVectorId),
    searchAllMessagesForUserId: v.optional(v.string()),
    textSearchMessages: v.optional(v.array(vMessageDoc)),
    messageRange: v.object({ before: v.number(), after: v.number() }),
    beforeMessageId: v.optional(v.id("messages")),
    limit: v.number(),
  },
  returns: v.array(vMessageDoc),
  handler: async (ctx, args) => {
    const beforeMessage =
      args.beforeMessageId && (await ctx.db.get(args.beforeMessageId));
    const { searchAllMessagesForUserId, threadId } = args;
    assert(
      searchAllMessagesForUserId || threadId,
      "Specify searchAllMessagesForUserId or threadId to search",
    );
    let messages = (
      await Promise.all(
        args.embeddingIds.map((embeddingId) =>
          ctx.db
            .query("messages")
            .withIndex("embeddingId_threadId", (q) =>
              searchAllMessagesForUserId
                ? q.eq("embeddingId", embeddingId)
                : q.eq("embeddingId", embeddingId).eq("threadId", threadId),
            )
            .filter((q) =>
              q.and(
                q.eq(q.field("status"), "success"),
                searchAllMessagesForUserId
                  ? q.eq(q.field("userId"), searchAllMessagesForUserId)
                  : q.eq(q.field("threadId"), threadId),
              ),
            )
            .first(),
        ),
      )
    )
      .filter(
        (m) =>
          m !== undefined &&
          m !== null &&
          !m.tool &&
          (!beforeMessage ||
            m.order < beforeMessage.order ||
            (m.order === beforeMessage.order &&
              m.stepOrder < beforeMessage.stepOrder)),
      )
      .map(publicMessage);
    messages.push(...(args.textSearchMessages ?? []));
    // TODO: prioritize more recent messages
    messages = sorted(messages);
    messages = messages.slice(0, args.limit);
    // Fetch the surrounding messages
    if (!threadId) {
      return messages;
    }
    const included = {};
    for (const m of messages) {
      const searchId = m.threadId ?? m.userId;
      if (!included[searchId]) {
        included[searchId] = new Set();
      }
      included[searchId].add(m.order);
    }
    const ranges = {};
    const { before, after } = args.messageRange;
    for (const m of messages) {
      const searchId = m.threadId ?? m.userId;
      const order = m.order;
      let earliest = order - before;
      let latest = order + after;
      // Trim orders already covered from both ends of the window.
      for (; earliest <= latest; earliest++) {
        if (!included[searchId].has(earliest)) {
          break;
        }
      }
      for (; latest >= earliest; latest--) {
        if (!included[searchId].has(latest)) {
          break;
        }
      }
      for (let i = earliest; i <= latest; i++) {
        included[searchId].add(i);
      }
      // A window trimmed down to a single order still needs fetching; only
      // an empty window (earliest > latest) is fully covered already.
      if (earliest <= latest) {
        const surrounding = await ctx.db
          .query("messages")
          .withIndex("threadId_status_tool_order_stepOrder", (q) =>
            q
              .eq("threadId", m.threadId)
              .eq("status", "success")
              .eq("tool", false)
              .gte("order", earliest)
              .lte("order", latest),
          )
          .collect();
        if (!ranges[searchId]) {
          ranges[searchId] = [];
        }
        ranges[searchId].push(...surrounding);
      }
    }
    for (const r of Object.values(ranges).flat()) {
      if (!messages.some((m) => m._id === r._id)) {
        messages.push(publicMessage(r));
      }
    }
    return sorted(messages);
  },
});
// Full-text search over message `text`, scoped to a thread or to a user;
// returns non-tool messages that precede the target message (if any).
/**
 * Full-text search over the `text` field of messages, scoped either to
 * `searchAllMessagesForUserId` or to `threadId`. The search text is
 * `args.text`, falling back to the target message's text. Returns at most
 * `limit` non-tool messages, restricted to those strictly before the target
 * message (by order, then stepOrder) when one is given.
 */
export const textSearch = query({
  args: {
    threadId: v.optional(v.id("threads")),
    searchAllMessagesForUserId: v.optional(v.string()),
    text: v.optional(v.string()),
    targetMessageId: v.optional(v.id("messages")),
    limit: v.number(),
  },
  handler: async (ctx, args) => {
    assert(
      args.searchAllMessagesForUserId || args.threadId,
      "Specify userId or threadId",
    );
    const targetMessage =
      args.targetMessageId && (await ctx.db.get(args.targetMessageId));
    const order = targetMessage?.order;
    const text = args.text || targetMessage?.text;
    if (!text) {
      console.warn("No text to search", targetMessage, args.text);
      return [];
    }
    const messages = await ctx.db
      .query("messages")
      .withSearchIndex("text_search", (q) =>
        args.searchAllMessagesForUserId
          ? q.search("text", text).eq("userId", args.searchAllMessagesForUserId)
          : q.search("text", text).eq("threadId", args.threadId),
      )
      // Just in case tool messages slip through
      .filter((q) => {
        const notTool = q.eq(q.field("tool"), false);
        // Compare with undefined: a target at order 0 must still bound the
        // query, otherwise later messages can use up `limit` and then be
        // discarded by the post-filter below.
        if (order !== undefined) {
          return q.and(notTool, q.lte(q.field("order"), order));
        }
        return notTool;
      })
      .take(args.limit);
    return messages
      .filter(
        (m) =>
          !targetMessage ||
          m.order < targetMessage.order ||
          (m.order === targetMessage.order &&
            m.stepOrder < targetMessage.stepOrder),
      )
      .map(publicMessage);
  },
  returns: v.array(vMessageDoc),
});

/**
 * Return a message's `text` plus the vector and model of its embedding.
 * Fields are undefined when the message or its embedding does not exist.
 */
export const getMessageSearchFields = query({
  args: {
    messageId: v.id("messages"),
  },
  returns: v.object({
    text: v.optional(v.string()),
    embedding: v.optional(v.array(v.number())),
    embeddingModel: v.optional(v.string()),
  }),
  handler: async (ctx, args) => {
    const message = await ctx.db.get(args.messageId);
    const text = message?.text;
    let embedding = undefined;
    let embeddingModel = undefined;
    if (message?.embeddingId) {
      const target = await ctx.db.get(message.embeddingId);
      embedding = target?.vector;
      embeddingModel = target?.model;
    }
    return {
      text,
      embedding,
      embeddingModel,
    };
  },
});
//# sourceMappingURL=messages.js.map