
@convex-dev/agent


An agent component for Convex.

378 lines · 15.4 kB
import { embedMany as embedMany_ } from "ai";
import { assert } from "convex-helpers";
import { validateVectorDimension } from "../component/vector/tables.js";
import {
  DEFAULT_MESSAGE_RANGE,
  DEFAULT_RECENT_MESSAGES,
  extractText,
  getModelName,
  getProviderName,
  isTool,
  sorted,
} from "../shared.js";
import { inlineMessagesFiles } from "./files.js";
import { docsToModelMessages, toModelMessage } from "../mapping.js";

const DEFAULT_VECTOR_SCORE_THRESHOLD = 0.0;
// 10k characters should be more than enough for most cases, and stays under
// the 8k token limit for some models.
const MAX_EMBEDDING_TEXT_LENGTH = 10_000;

/**
 * Fetch the context messages for a thread.
 * @param ctx Either a query, mutation, or action ctx.
 *   If it is not an action context, you can't do text or
 *   vector search.
 * @param component The agent component, usually components.agent.
 * @param args The associated thread, user, message.
 * @returns The search messages followed by the recent messages.
 */
export async function fetchContextMessages(ctx, component, args) {
  const { recentMessages, searchMessages } = await fetchRecentAndSearchMessages(ctx, component, args);
  return [...searchMessages, ...recentMessages];
}
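
// Usage sketch (not part of the published file): calling fetchContextMessages
// from a Convex action. The `action`, `v`, and `components` imports come from
// Convex's generated/app code and are assumed here; the option names follow
// fetchRecentAndSearchMessages below, but treat this as an illustration, not
// the package's documented API.
//
//   import { action } from "./_generated/server";
//   import { components } from "./_generated/api";
//   import { v } from "convex/values";
//
//   export const threadContext = action({
//     args: { threadId: v.string() },
//     handler: async (ctx, { threadId }) =>
//       fetchContextMessages(ctx, components.agent, {
//         threadId,
//         contextOptions: { recentMessages: 20 },
//       }),
//   });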
export async function fetchRecentAndSearchMessages(ctx, component, args) {
  assert(args.userId || args.threadId, "Specify userId or threadId");
  const opts = args.contextOptions;
  // Fetch the latest messages from the thread
  let included;
  let recentMessages = [];
  let searchMessages = [];
  const targetMessageId = args.targetMessageId ?? args.upToAndIncludingMessageId;
  if (args.threadId && opts.recentMessages !== 0) {
    const { page } = await ctx.runQuery(component.messages.listMessagesByThreadId, {
      threadId: args.threadId,
      excludeToolMessages: opts.excludeToolMessages,
      paginationOpts: {
        numItems: opts.recentMessages ?? DEFAULT_RECENT_MESSAGES,
        cursor: null,
      },
      upToAndIncludingMessageId: targetMessageId,
      order: "desc",
      statuses: ["success"],
    });
    included = new Set(page.map((m) => m._id));
    recentMessages = filterOutOrphanedToolMessages(sorted(page));
  }
  if ((opts.searchOptions?.textSearch || opts.searchOptions?.vectorSearch) &&
      opts.searchOptions?.limit) {
    if (!("runAction" in ctx)) {
      throw new Error("searchUserMessages only works in an action");
    }
    let text = args.searchText;
    let embedding;
    let embeddingModel;
    if (!text) {
      if (targetMessageId) {
        const targetMessage = recentMessages.find((m) => m._id === targetMessageId);
        if (targetMessage) {
          text = targetMessage.text;
        } else {
          const targetSearchFields = await ctx.runQuery(component.messages.getMessageSearchFields, {
            messageId: targetMessageId,
          });
          text = targetSearchFields.text;
          embedding = targetSearchFields.embedding;
          embeddingModel = targetSearchFields.embeddingModel;
        }
        assert(text, "Target message has no text for searching");
      } else if (args.messages?.length) {
        text = extractText(args.messages.at(-1));
        assert(text, "Final context message has no text to search");
      }
      assert(text, "No text to search");
    }
    if (opts.searchOptions?.vectorSearch) {
      if (!embedding && args.getEmbedding) {
        const embeddingFields = await args.getEmbedding(text);
        embedding = embeddingFields.embedding;
        embeddingModel = embeddingFields.textEmbeddingModel
          ? getModelName(embeddingFields.textEmbeddingModel)
          : undefined;
        // TODO: if the text matches the target message, save the embedding
        // for the target message and return the embeddingId on the message.
      }
    }
    const searchResults = await ctx.runAction(component.messages.searchMessages, {
      searchAllMessagesForUserId: opts?.searchOtherThreads
        ? (args.userId ??
          (args.threadId &&
            (await ctx.runQuery(component.threads.getThread, {
              threadId: args.threadId,
            }))?.userId))
        : undefined,
      threadId: args.threadId,
      targetMessageId,
      limit: opts.searchOptions?.limit ?? 10,
      messageRange: { ...DEFAULT_MESSAGE_RANGE, ...opts.searchOptions?.messageRange },
      text,
      textSearch: opts.searchOptions?.textSearch,
      vectorSearch: opts.searchOptions?.vectorSearch,
      vectorScoreThreshold: opts.searchOptions?.vectorScoreThreshold ?? DEFAULT_VECTOR_SCORE_THRESHOLD,
      embedding,
      embeddingModel,
    });
    // TODO: track what messages we used for context
    searchMessages = filterOutOrphanedToolMessages(
      sorted(searchResults.filter((m) => !included?.has(m._id))),
    );
  }
  // Ensure we don't include tool messages without a corresponding tool call
  return { recentMessages, searchMessages };
}

/**
 * Filter out tool messages that don't have both a tool call and response.
 * @param docs The messages to filter.
 * @returns The filtered messages.
 */
export function filterOutOrphanedToolMessages(docs) {
  const toolCallIds = new Set();
  const toolResultIds = new Set();
  const result = [];
  for (const doc of docs) {
    if (doc.message && Array.isArray(doc.message.content)) {
      for (const content of doc.message.content) {
        if (content.type === "tool-call") {
          toolCallIds.add(content.toolCallId);
        } else if (content.type === "tool-result") {
          toolResultIds.add(content.toolCallId);
        }
      }
    }
  }
  for (const doc of docs) {
    if (doc.message?.role === "assistant" && Array.isArray(doc.message.content)) {
      const content = doc.message.content.filter((p) => p.type !== "tool-call" || toolResultIds.has(p.toolCallId));
      if (content.length) {
        result.push({ ...doc, message: { ...doc.message, content } });
      }
    } else if (doc.message?.role === "tool") {
      const content = doc.message.content.filter((c) => toolCallIds.has(c.toolCallId));
      if (content.length) {
        result.push({ ...doc, message: { ...doc.message, content } });
      }
    } else {
      result.push(doc);
    }
  }
  return result;
}
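
// Illustration (sketch, not part of the published file): the filter above
// drops a "tool-call" part when no "tool-result" with the same toolCallId is
// in the batch, and drops the message entirely if no parts remain. Docs are
// simplified to the fields the filter reads.
//
//   const kept = filterOutOrphanedToolMessages([
//     { message: { role: "assistant", content: [
//       { type: "tool-call", toolCallId: "call_1" }, // orphaned: no result
//     ] } },
//     { message: { role: "user", content: "hi" } }, // passed through as-is
//   ]);
//   // kept contains only the user message.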
/**
 * Embed a list of messages, including calling any usage handler.
 * This will not save the embeddings to the database.
 */
export async function embedMessages(ctx, { userId, threadId, ...options }, messages) {
  if (!options.textEmbeddingModel) {
    return undefined;
  }
  let embeddings;
  const messageTexts = messages.map((m) => !isTool(m) && extractText(m));
  // Find the indexes of the messages that have text.
  const textIndexes = messageTexts
    .map((t, i) => (t ? i : undefined))
    .filter((i) => i !== undefined);
  if (textIndexes.length === 0) {
    return undefined;
  }
  const values = messageTexts
    .map((t) => t && t.trim().slice(0, MAX_EMBEDDING_TEXT_LENGTH))
    .filter((t) => !!t);
  // Then embed those messages.
  const textEmbeddings = await embedMany(ctx, { ...options, userId, threadId, values });
  // Then assemble the embeddings into a single array with nulls for the messages without text.
  const embeddingsOrNull = Array(messages.length).fill(null);
  textIndexes.forEach((i, j) => {
    embeddingsOrNull[i] = textEmbeddings.embeddings[j];
  });
  if (textEmbeddings.embeddings.length > 0) {
    const dimension = textEmbeddings.embeddings[0].length;
    validateVectorDimension(dimension);
    const model = getModelName(options.textEmbeddingModel);
    embeddings = { vectors: embeddingsOrNull, dimension, model };
  }
  return embeddings;
}

/**
 * Embeds many strings, calling any usage handler.
 * @param ctx The ctx parameter to an action.
 * @param args Arguments to AI SDK's embedMany, and context for the embedding,
 *   passed to the usage handler.
 * @returns The embeddings for the strings, matching the order of the values.
 */
export async function embedMany(ctx, {
  userId,
  threadId,
  values,
  abortSignal,
  headers,
  agentName,
  usageHandler,
  textEmbeddingModel,
  callSettings,
}) {
  const embeddingModel = textEmbeddingModel;
  assert(embeddingModel, "a textEmbeddingModel is required to be set for vector search");
  const result = await embedMany_({
    ...callSettings,
    model: embeddingModel,
    values,
    abortSignal,
    headers,
  });
  if (usageHandler && result.usage) {
    await usageHandler(ctx, {
      userId,
      threadId,
      agentName,
      model: getModelName(embeddingModel),
      provider: getProviderName(embeddingModel),
      providerMetadata: undefined,
      usage: {
        inputTokens: result.usage.tokens,
        outputTokens: 0,
        totalTokens: result.usage.tokens,
      },
    });
  }
  return { embeddings: result.embeddings };
}

/**
 * Embed a list of messages, and save the embeddings to the database.
 * @param ctx The ctx parameter to an action.
 * @param component The agent component, usually components.agent.
 * @param args The context for the embedding, passed to the usage handler.
 * @param messages The messages to embed, in the Agent MessageDoc format.
 */
export async function generateAndSaveEmbeddings(ctx, component, args, messages) {
  const toEmbed = messages.filter((m) => !m.embeddingId && m.message);
  if (toEmbed.length === 0) {
    return;
  }
  const embeddings = await embedMessages(ctx, args, toEmbed.map((m) => m.message));
  if (embeddings && embeddings.vectors.some((v) => v !== null)) {
    await ctx.runMutation(component.vector.index.insertBatch, {
      vectorDimension: embeddings.dimension,
      vectors: toEmbed
        .map((m, i) => ({
          messageId: m._id,
          model: embeddings.model,
          table: "messages",
          userId: m.userId,
          threadId: m.threadId,
          vector: embeddings.vectors[i],
        }))
        .filter((v) => v.vector !== null),
    });
  }
}
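
// Usage sketch (not part of the published file): backfilling embeddings for
// already-saved MessageDocs from an action. The OpenAI model here is only an
// example of an AI SDK embedding model; any textEmbeddingModel works, and
// `messageDocs` stands in for messages you have already fetched.
//
//   import { openai } from "@ai-sdk/openai";
//
//   await generateAndSaveEmbeddings(ctx, components.agent, {
//     userId,
//     threadId,
//     textEmbeddingModel: openai.textEmbeddingModel("text-embedding-3-small"),
//   }, messageDocs);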
/**
 * Similar to fetchContextMessages, but also merges in the input messages:
 * the result is search context, then recent messages, then input messages,
 * then prompt messages. If both a promptMessageId and prompt message(s) are
 * provided, the prompt messages are spliced into the history in place of the
 * promptMessageId message, still followed by any existing messages that were
 * in response to the promptMessageId message.
 */
export async function fetchContextWithPrompt(ctx, component, args) {
  const { threadId, userId, textEmbeddingModel } = args;
  const promptArray = getPromptArray(args.prompt);
  const searchText = promptArray.length
    ? extractText(promptArray.at(-1))
    : args.promptMessageId
      ? undefined
      : args.messages?.at(-1)
        ? extractText(args.messages.at(-1))
        : undefined;
  // If only a messageId is provided, this will add that message to the end.
  const { recentMessages, searchMessages } = await fetchRecentAndSearchMessages(ctx, component, {
    userId,
    threadId,
    targetMessageId: args.promptMessageId,
    searchText,
    contextOptions: args.contextOptions ?? {},
    getEmbedding: async (text) => {
      assert(textEmbeddingModel, "A textEmbeddingModel is required to be set on the Agent that you're doing vector search with");
      return {
        embedding: (await embedMany(ctx, {
          ...args,
          userId,
          values: [text],
          textEmbeddingModel,
        })).embeddings[0],
        textEmbeddingModel,
      };
    },
  });
  const promptMessageIndex = args.promptMessageId
    ? recentMessages.findIndex((m) => m._id === args.promptMessageId)
    : -1;
  const promptMessage = promptMessageIndex !== -1 ? recentMessages[promptMessageIndex] : undefined;
  let prePromptDocs = recentMessages;
  const messages = args.messages ?? [];
  let existingResponseDocs = [];
  if (promptMessage) {
    prePromptDocs = recentMessages.slice(0, promptMessageIndex);
    existingResponseDocs = recentMessages.slice(promptMessageIndex + 1);
    if (promptArray.length === 0) {
      // If they didn't override the prompt, use the existing prompt message.
      if (promptMessage.message) {
        promptArray.push(promptMessage.message);
      }
    }
    if (!promptMessage.embeddingId && textEmbeddingModel) {
      // Lazily generate embeddings for the prompt message, if it doesn't have
      // embeddings yet. This can happen if the message was saved in a mutation
      // where the LLM is not available.
      await generateAndSaveEmbeddings(ctx, component, { ...args, userId, textEmbeddingModel }, [promptMessage]);
    }
  }
  const search = docsToModelMessages(searchMessages);
  const recent = docsToModelMessages(prePromptDocs);
  const inputMessages = messages.map(toModelMessage);
  const inputPrompt = promptArray.map(toModelMessage);
  const existingResponses = docsToModelMessages(existingResponseDocs);
  const allMessages = [
    ...search,
    ...recent,
    ...inputMessages,
    ...inputPrompt,
    ...existingResponses,
  ];
  let processedMessages = args.contextHandler
    ? await args.contextHandler(ctx, {
        allMessages,
        search,
        recent,
        inputMessages,
        inputPrompt,
        existingResponses,
        userId,
        threadId,
      })
    : allMessages;
  // Inline localhost files into the messages (otherwise file urls pointing to
  // localhost would be sent to LLM providers).
  if (process.env.CONVEX_CLOUD_URL?.startsWith("http://127.0.0.1")) {
    processedMessages = await inlineMessagesFiles(processedMessages);
  }
  return {
    messages: processedMessages,
    order: promptMessage?.order,
    stepOrder: promptMessage?.stepOrder,
  };
}

export function getPromptArray(prompt) {
  return !prompt
    ? []
    : Array.isArray(prompt)
      ? prompt
      : [{ role: "user", content: prompt }];
}
//# sourceMappingURL=search.js.map
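
// Usage sketch (not part of the published file): fetchContextWithPrompt is the
// high-level entry point, returning one ModelMessage array ordered as search
// context, recent history, input messages, prompt, then any existing
// responses. Option values below are illustrative.
//
//   const { messages, order, stepOrder } = await fetchContextWithPrompt(
//     ctx, components.agent, {
//       userId,
//       threadId,
//       prompt: "Summarize our discussion so far.",
//       contextOptions: {
//         recentMessages: 50,
//         searchOptions: { limit: 10, vectorSearch: true },
//       },
//       textEmbeddingModel,
//     });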