@convex-dev/agent
Version:
An agent component for Convex.
572 lines • 23.5 kB
JavaScript
import { assert, omit, pick } from "convex-helpers";
import { mergedStream, stream } from "convex-helpers/server/stream";
import { paginationOptsValidator, } from "convex/server";
import { DEFAULT_MESSAGE_RANGE, DEFAULT_RECENT_MESSAGES, extractText, isTool, } from "../shared.js";
import { vMessageEmbeddingsWithDimension, vMessageStatus, vMessageWithMetadataInternal, vPaginationResult, } from "../validators.js";
import { api, internal } from "./_generated/api.js";
import { action, internalQuery, mutation, query, } from "./_generated/server.js";
import { schema, v, vMessageDoc } from "./schema.js";
import { insertVector, searchVectors } from "./vector/index.js";
import { VectorDimensions, vVectorId, } from "./vector/tables.js";
import { changeRefcount } from "./files.js";
import { getStreamingMessagesWithMetadata } from "./streams.js";
import { partial } from "convex-helpers/validators";
/**
 * Strip internal bookkeeping fields from a message doc before handing it to
 * clients. Returns a shallow copy without parentMessageId, stepId, or files.
 */
function publicMessage(message) {
  const { parentMessageId, stepId, files, ...visible } = message;
  return visible;
}
/**
 * Delete a message document, its embedding (if any), and release references
 * to any files it owned.
 */
export async function deleteMessage(ctx, messageDoc) {
  const { _id, embeddingId, fileIds } = messageDoc;
  await ctx.db.delete(_id);
  if (embeddingId) {
    await ctx.db.delete(embeddingId);
  }
  if (fileIds) {
    // Drop all file references held by this message.
    await changeRefcount(ctx, fileIds, []);
  }
}
/**
 * Delete the given messages by id. Ids that no longer exist are skipped.
 * Returns the ids that were actually deleted.
 */
export const deleteByIds = mutation({
  args: { messageIds: v.array(v.id("messages")) },
  returns: v.array(v.id("messages")),
  handler: async (ctx, args) => {
    const results = await Promise.all(
      args.messageIds.map(async (messageId) => {
        const doc = await ctx.db.get(messageId);
        if (!doc) {
          return null;
        }
        await deleteMessage(ctx, doc);
        return messageId;
      })
    );
    return results.filter((messageId) => messageId !== null);
  },
});
// All possible values of a message's `status` field, derived from the
// vMessageDoc validator so the list stays in sync with the schema.
export const messageStatuses = vMessageDoc.fields.status.members.map((m) => m.value);
/**
 * Delete a thread's messages within the [start, end) range of
 * (order, stepOrder) positions, in batches of up to 64.
 * Returns isDone (true when fewer than a full batch was found) plus the last
 * deleted position so callers can resume the next batch from there.
 */
export const deleteByOrder = mutation({
  args: {
    threadId: v.id("threads"),
    startOrder: v.number(),
    startStepOrder: v.optional(v.number()),
    endOrder: v.number(),
    endStepOrder: v.optional(v.number()),
  },
  returns: v.object({
    isDone: v.boolean(),
    lastOrder: v.optional(v.number()),
    lastStepOrder: v.optional(v.number()),
  }),
  handler: async (ctx, args) => {
    // Check `!== undefined` rather than truthiness so an explicit stepOrder
    // of 0 is honored as a bound instead of being silently dropped.
    const messages = await orderedMessagesStream(ctx, args.threadId, "asc", args.startOrder)
      .narrow({
        lowerBound:
          args.startStepOrder !== undefined
            ? [args.startOrder, args.startStepOrder]
            : [args.startOrder],
        lowerBoundInclusive: true,
        upperBound:
          args.endStepOrder !== undefined
            ? [args.endOrder, args.endStepOrder]
            : [args.endOrder],
        upperBoundInclusive: false,
      })
      .take(64);
    await Promise.all(messages.map((m) => deleteMessage(ctx, m)));
    return {
      // Fewer than a full batch means the range has been exhausted.
      isDone: messages.length < 64,
      lastOrder: messages.at(-1)?.order,
      lastStepOrder: messages.at(-1)?.stepOrder,
    };
  },
});
// Shared argument validator for addMessages / addMessagesHandler.
const addMessagesArgs = {
userId: v.optional(v.string()),
threadId: v.id("threads"),
// The prompt (usually user) message this batch of messages responds to.
promptMessageId: v.optional(v.id("messages")),
agentName: v.optional(v.string()),
messages: v.array(vMessageWithMetadataInternal),
// Optional embeddings, parallel to `messages` (one vector per message).
embeddings: v.optional(vMessageEmbeddingsWithDimension),
// When true, mark other pending messages at the same order as failed.
failPendingSteps: v.optional(v.boolean()),
// A pending message to update. If the pending message failed, abort.
pendingMessageId: v.optional(v.id("messages")),
};
/**
 * Public mutation wrapper around addMessagesHandler: appends a batch of
 * messages to a thread and returns the stored docs (public fields only).
 */
export const addMessages = mutation({
  args: addMessagesArgs,
  returns: v.object({ messages: v.array(vMessageDoc) }),
  handler: addMessagesHandler,
});
/**
 * Core implementation of addMessages: appends a batch of messages to a
 * thread, assigning (order, stepOrder) positions, inserting embeddings, and
 * optionally failing other pending steps or replacing a pending message.
 * Returns { messages } with public (client-safe) docs in insertion order.
 */
async function addMessagesHandler(ctx, args) {
let userId = args.userId;
const threadId = args.threadId;
// Fall back to the thread's owner when no explicit userId was given.
if (!userId && args.threadId) {
const thread = await ctx.db.get(args.threadId);
assert(thread, `Thread ${args.threadId} not found`);
userId = thread.userId;
}
// `rest` carries pass-through metadata (e.g. agentName) onto each doc.
const { embeddings, failPendingSteps, messages, promptMessageId, pendingMessageId, ...rest } = args;
const promptMessage = promptMessageId && (await ctx.db.get(promptMessageId));
if (failPendingSteps) {
assert(args.threadId, "threadId is required to fail pending steps");
// Mark other in-flight pending messages (at the prompt's order, if known)
// as failed with "Restarting", deleting their embeddings. The message
// we're about to replace (pendingMessageId) is exempt.
const pendingMessages = await ctx.db
.query("messages")
.withIndex("threadId_status_tool_order_stepOrder", (q) => q.eq("threadId", threadId).eq("status", "pending"))
.order("desc")
.take(100);
await Promise.all(pendingMessages
.filter((m) => !promptMessage || m.order === promptMessage.order)
.filter((m) => !pendingMessageId || m._id !== pendingMessageId)
.map(async (m) => {
if (m.embeddingId) {
await ctx.db.delete(m.embeddingId);
}
await ctx.db.patch(m._id, {
status: "failed",
error: "Restarting",
embeddingId: undefined,
});
}));
}
let order, stepOrder;
// When the prompt message has failed, every added message is marked failed
// with its error, and no embeddings are stored.
let fail = false;
let error;
if (promptMessageId) {
assert(promptMessage, `Parent message ${promptMessageId} not found`);
if (promptMessage.status === "failed") {
fail = true;
error = promptMessage.error ?? error ?? "The prompt message failed";
}
// Continue the prompt message's turn: same order, next stepOrder.
order = promptMessage.order;
// Defend against there being existing messages with this parent.
const maxMessage = await getMaxMessage(ctx, threadId, order);
stepOrder = maxMessage?.stepOrder ?? promptMessage.stepOrder;
}
else {
// No prompt message: continue after the latest message in the thread.
// The -1 sentinels make the first message land at order/stepOrder 0.
const maxMessage = await getMaxMessage(ctx, threadId);
order = maxMessage?.order ?? -1;
stepOrder = maxMessage?.stepOrder ?? -1;
}
const toReturn = [];
if (embeddings) {
assert(embeddings.vectors.length === messages.length, "embeddings.vectors.length must match messages.length");
}
for (let i = 0; i < messages.length; i++) {
const message = messages[i];
let embeddingId;
// Only store an embedding for messages that aren't (being marked) failed.
if (embeddings &&
embeddings.vectors[i] &&
!fail &&
message.status !== "failed") {
embeddingId = await insertVector(ctx, embeddings.dimension, {
vector: embeddings.vectors[i],
model: embeddings.model,
table: "messages",
userId,
threadId,
});
}
const messageDoc = {
...rest,
...message,
embeddingId,
parentMessageId: promptMessageId,
userId,
// Denormalized fields derived from the message body.
tool: isTool(message.message),
text: extractText(message.message),
status: fail ? "failed" : (message.status ?? "success"),
error: fail ? error : message.error,
};
// If there is a pending message, we replace that one with the first message
// and subsequent ones will follow the regular order/subOrder advancement.
if (i === 0 && pendingMessageId) {
const pendingMessage = await ctx.db.get(pendingMessageId);
assert(pendingMessage, `Pending msg ${pendingMessageId} not found`);
if (pendingMessage.status === "failed") {
// Propagate the failure to this and all subsequent messages in the batch.
fail = true;
error =
`Trying to update a message that failed: ${pendingMessageId}, ` +
`error: ${pendingMessage.error ?? error}`;
messageDoc.status = "failed";
messageDoc.error = error;
}
if (message.fileIds) {
await changeRefcount(ctx, pendingMessage.fileIds ?? [], message.fileIds);
}
// Keep the pending message's position; replace its contents.
await ctx.db.replace(pendingMessage._id, {
...messageDoc,
order: pendingMessage.order,
stepOrder: pendingMessage.stepOrder,
});
// NOTE(review): this pushes the pre-replace doc, so the returned fields
// are stale except for identity/position — confirm that's intentional.
toReturn.push(pendingMessage);
continue;
}
if (message.message.role === "user") {
// User messages start a new turn: new order, stepOrder reset to 0.
if (promptMessage && promptMessage.order === order) {
// see if there's a later message than the parent message order
const maxMessage = await getMaxMessage(ctx, threadId);
order = (maxMessage?.order ?? order) + 1;
}
else {
order++;
}
stepOrder = 0;
}
else {
// Non-user messages advance within the current order.
if (order < 0) {
order = 0;
}
stepOrder++;
}
const messageId = await ctx.db.insert("messages", {
...messageDoc,
order,
stepOrder,
});
if (message.fileIds) {
await changeRefcount(ctx, [], message.fileIds);
}
// TODO: delete the associated stream data for the order/stepOrder
toReturn.push((await ctx.db.get(messageId)));
}
return { messages: toReturn.map(publicMessage) };
}
// exported for tests
/**
 * Latest (order, stepOrder) message in the thread, optionally restricted to a
 * single `order`. Resolves to undefined when no messages match.
 */
export async function getMaxMessage(ctx, threadId, order) {
  const descending = orderedMessagesStream(ctx, threadId, "desc", order);
  return descending.first();
}
/**
 * Stream a thread's messages sorted by (order, stepOrder), merging the index
 * ranges for every (tool, status) combination. When `order` is given, the
 * stream is restricted to that single order.
 */
function orderedMessagesStream(ctx, threadId, sortOrder, order) {
  const streams = [];
  for (const tool of [true, false]) {
    for (const status of messageStatuses) {
      streams.push(
        stream(ctx.db, schema)
          .query("messages")
          .withIndex("threadId_status_tool_order_stepOrder", (q) => {
            const base = q
              .eq("threadId", threadId)
              .eq("status", status)
              .eq("tool", tool);
            return order === undefined ? base : base.eq("order", order);
          })
          .order(sortOrder)
      );
    }
  }
  return mergedStream(streams, ["order", "stepOrder"]);
}
/**
 * Mark a pending message as succeeded or failed. If the message has no body
 * yet, first try to recover any streamed content and store it via
 * addMessagesHandler (replacing the pending message in place).
 */
export const finalizeMessage = mutation({
  args: {
    messageId: v.id("messages"),
    result: v.union(v.object({ status: v.literal("success") }), v.object({ status: v.literal("failed"), error: v.string() })),
  },
  returns: v.null(),
  handler: async (ctx, { messageId, result }) => {
    const message = await ctx.db.get(messageId);
    assert(message, `Message ${messageId} not found`);
    if (message.status !== "pending") {
      console.log("Trying to finalize a message that's already", message.status);
      return;
    }
    // See if we can add any in-progress data
    if (message.message === undefined) {
      const streamed = await getStreamingMessagesWithMetadata(ctx, message, result);
      if (streamed.length > 0) {
        await addMessagesHandler(ctx, {
          messages: streamed,
          threadId: message.threadId,
          agentName: message.agentName,
          failPendingSteps: false,
          pendingMessageId: messageId,
          userId: message.userId,
          embeddings: undefined,
        });
        return;
      }
    }
    if (result.status === "success") {
      await ctx.db.patch(messageId, { status: "success" });
      return;
    }
    // Failed: drop the embedding (if any) and record the error.
    if (message.embeddingId) {
      await ctx.db.delete(message.embeddingId);
    }
    await ctx.db.patch(messageId, {
      status: "failed",
      error: result.error,
      embeddingId: undefined,
    });
  },
});
/**
 * Patch selected fields of a message. Recomputes the denormalized
 * `tool`/`text` fields when the message body changes, adjusts file
 * refcounts, and drops the embedding when the message is marked failed.
 * Returns the updated public doc.
 */
export const updateMessage = mutation({
  args: {
    messageId: v.id("messages"),
    patch: v.object(partial(pick(schema.tables.messages.validator.fields, [
      "message",
      "fileIds",
      "status",
      "error",
      "model",
      "provider",
      "providerOptions",
      "finishReason",
    ]))),
  },
  returns: vMessageDoc,
  handler: async (ctx, args) => {
    const existing = await ctx.db.get(args.messageId);
    assert(existing, `Message ${args.messageId} not found`);
    if (args.patch.fileIds) {
      await changeRefcount(ctx, existing.fileIds ?? [], args.patch.fileIds);
    }
    const patch = { ...args.patch };
    if (args.patch.message !== undefined) {
      // Keep the derived fields in sync with the new message body.
      patch.message = args.patch.message;
      patch.tool = isTool(args.patch.message);
      patch.text = extractText(args.patch.message);
    }
    if (args.patch.status === "failed") {
      if (existing.embeddingId) {
        await ctx.db.delete(existing.embeddingId);
      }
      patch.embeddingId = undefined;
    }
    await ctx.db.patch(args.messageId, patch);
    return publicMessage((await ctx.db.get(args.messageId)));
  },
});
/**
 * Paginated listing of a thread's messages, merged across the per-(status,
 * tool) index ranges and sorted by (order, stepOrder). Optionally bounded by
 * an inclusive cutoff message and/or restricted to specific statuses.
 */
export const listMessagesByThreadId = query({
  args: {
    threadId: v.id("threads"),
    excludeToolMessages: v.optional(v.boolean()),
    /** What order to sort the messages in. To get the latest, use "desc". */
    order: v.union(v.literal("asc"), v.literal("desc")),
    paginationOpts: v.optional(paginationOptsValidator),
    statuses: v.optional(v.array(vMessageStatus)),
    upToAndIncludingMessageId: v.optional(v.id("messages")),
  },
  returns: vPaginationResult(vMessageDoc),
  handler: async (ctx, args) => {
    const statuses = args.statuses ?? vMessageStatus.members.map((m) => m.value);
    const last = args.upToAndIncludingMessageId &&
      (await ctx.db.get(args.upToAndIncludingMessageId));
    assert(!last || last.threadId === args.threadId, "upToAndIncludingMessageId must be a message in the thread");
    const toolOptions = args.excludeToolMessages ? [false] : [true, false];
    const order = args.order ?? "desc";
    const streams = [];
    for (const tool of toolOptions) {
      for (const status of statuses) {
        streams.push(
          stream(ctx.db, schema)
            .query("messages")
            .withIndex("threadId_status_tool_order_stepOrder", (q) => {
              const qq = q
                .eq("threadId", args.threadId)
                .eq("status", status)
                .eq("tool", tool);
              return last ? qq.lte("order", last.order) : qq;
            })
            .order(order)
            .filterWith(
              // We allow all messages on the same order.
              async (m) => !last || m.order < last.order || m.order === last.order
            )
        );
      }
    }
    const merged = mergedStream(streams, ["order", "stepOrder"]);
    const messages = await merged.paginate(args.paginationOpts ?? {
      numItems: DEFAULT_RECENT_MESSAGES,
      cursor: null,
    });
    return { ...messages, page: messages.page.map(publicMessage) };
  },
});
/**
 * Fetch messages by id, preserving input order; missing ids yield null.
 */
export const getMessagesByIds = query({
  args: { messageIds: v.array(v.id("messages")) },
  returns: v.array(v.union(v.null(), vMessageDoc)),
  handler: async (ctx, args) => {
    const docs = await Promise.all(args.messageIds.map((id) => ctx.db.get(id)));
    return docs.map((doc) => (doc ? publicMessage(doc) : null));
  },
});
/**
 * Hybrid search over messages: combines full-text search and vector search
 * (when an embedding is supplied) via reciprocal rank fusion, then fetches
 * the fused results (plus surrounding context) through _fetchSearchMessages.
 * Requires either threadId (thread-scoped) or searchAllMessagesForUserId.
 */
export const searchMessages = action({
  args: {
    threadId: v.optional(v.id("threads")),
    searchAllMessagesForUserId: v.optional(v.string()),
    beforeMessageId: v.optional(v.id("messages")),
    embedding: v.optional(v.array(v.number())),
    embeddingModel: v.optional(v.string()),
    text: v.optional(v.string()),
    limit: v.number(),
    vectorScoreThreshold: v.optional(v.number()),
    messageRange: v.optional(v.object({ before: v.number(), after: v.number() })),
  },
  returns: v.array(vMessageDoc),
  handler: async (ctx, args) => {
    assert(args.searchAllMessagesForUserId || args.threadId, "Specify userId or threadId");
    const limit = args.limit;
    let textSearchMessages;
    if (args.text) {
      textSearchMessages = await ctx.runQuery(api.messages.textSearch, {
        searchAllMessagesForUserId: args.searchAllMessagesForUserId,
        threadId: args.threadId,
        text: args.text,
        limit,
        beforeMessageId: args.beforeMessageId,
      });
    }
    if (args.embedding) {
      const dimension = args.embedding.length;
      if (!VectorDimensions.includes(dimension)) {
        throw new Error(`Unsupported embedding dimension: ${dimension}`);
      }
      const vectors = (await searchVectors(ctx, args.embedding, {
        dimension,
        model: args.embeddingModel ?? "unknown",
        table: "messages",
        searchAllMessagesForUserId: args.searchAllMessagesForUserId,
        threadId: args.threadId,
        limit,
      })).filter((v) => v._score > (args.vectorScoreThreshold ?? 0));
      // Reciprocal rank fusion: score = 1/(vectorRank + k) + 1/(textRank + k),
      // where a result absent from the text ranking contributes 0.
      const k = 10;
      const textEmbeddingIds = textSearchMessages?.map((m) => m.embeddingId);
      const vectorScores = vectors
        .map((v, i) => {
          // BUG FIX: Array.prototype.indexOf returns -1 (not undefined) on a
          // miss, so the previous `?? Infinity` fallback never applied and
          // misses scored 1/(k-1) — higher than the best text hit at 1/k.
          const textRank = textEmbeddingIds ? textEmbeddingIds.indexOf(v._id) : -1;
          return {
            id: v._id,
            score: 1 / (i + k) + (textRank === -1 ? 0 : 1 / (textRank + k)),
          };
        })
        .sort((a, b) => b.score - a.score);
      const embeddingIds = vectorScores.slice(0, limit).map((v) => v.id);
      const messages = await ctx.runQuery(internal.messages._fetchSearchMessages, {
        searchAllMessagesForUserId: args.searchAllMessagesForUserId,
        threadId: args.threadId,
        embeddingIds,
        // Pass along text hits not already covered by the vector results.
        textSearchMessages: textSearchMessages?.filter((m) => !embeddingIds.includes(m.embeddingId)),
        messageRange: args.messageRange ?? DEFAULT_MESSAGE_RANGE,
        beforeMessageId: args.beforeMessageId,
        limit,
      });
      return messages;
    }
    // Text-only search (or neither text nor embedding supplied).
    return textSearchMessages?.flat() ?? [];
  },
});
/**
 * Internal query backing searchMessages: resolves vector-search embedding ids
 * to message docs, merges in text-search hits, truncates to `limit`, and (for
 * thread-scoped searches) expands each hit with the surrounding messages in
 * its thread per messageRange. Results are sorted by ascending order.
 */
export const _fetchSearchMessages = internalQuery({
  args: {
    threadId: v.optional(v.id("threads")),
    embeddingIds: v.array(vVectorId),
    searchAllMessagesForUserId: v.optional(v.string()),
    textSearchMessages: v.optional(v.array(vMessageDoc)),
    messageRange: v.object({ before: v.number(), after: v.number() }),
    beforeMessageId: v.optional(v.id("messages")),
    limit: v.number(),
  },
  returns: v.array(vMessageDoc),
  handler: async (ctx, args) => {
    const beforeMessage = args.beforeMessageId && (await ctx.db.get(args.beforeMessageId));
    const { searchAllMessagesForUserId, threadId } = args;
    assert(searchAllMessagesForUserId || threadId, "Specify searchAllMessagesForUserId or threadId to search");
    // Resolve each embedding to its message, keeping only successful,
    // non-tool messages strictly before the (order, stepOrder) cutoff.
    let messages = (await Promise.all(args.embeddingIds.map((embeddingId) => ctx.db
      .query("messages")
      .withIndex("embeddingId_threadId", (q) => searchAllMessagesForUserId
        ? q.eq("embeddingId", embeddingId)
        : q.eq("embeddingId", embeddingId).eq("threadId", threadId))
      .filter((q) => q.and(q.eq(q.field("status"), "success"), searchAllMessagesForUserId
        ? q.eq(q.field("userId"), searchAllMessagesForUserId)
        : q.eq(q.field("threadId"), threadId)))
      .first())))
      .filter((m) => m !== undefined &&
        m !== null &&
        !m.tool &&
        (!beforeMessage ||
          m.order < beforeMessage.order ||
          (m.order === beforeMessage.order &&
            m.stepOrder < beforeMessage.stepOrder)))
      .map(publicMessage);
    messages.push(...(args.textSearchMessages ?? []));
    // TODO: prioritize more recent messages
    messages.sort((a, b) => a.order - b.order);
    messages = messages.slice(0, args.limit);
    // Fetch the surrounding messages
    if (!threadId) {
      // Cross-thread (user-wide) search: no surrounding context is fetched.
      return messages.sort((a, b) => a.order - b.order);
    }
    // Track, per thread (or user), which `order` values are already covered.
    const included = {};
    for (const m of messages) {
      const searchId = m.threadId ?? m.userId;
      if (!included[searchId]) {
        included[searchId] = new Set();
      }
      included[searchId].add(m.order);
    }
    const ranges = {};
    const { before, after } = args.messageRange;
    for (const m of messages) {
      const searchId = m.threadId ?? m.userId;
      const order = m.order;
      let earliest = order - before;
      let latest = order + after;
      // Shrink the window from both ends past orders already included, so
      // adjacent hits don't fetch overlapping ranges.
      for (; earliest <= latest; earliest++) {
        if (!included[searchId].has(earliest)) {
          break;
        }
      }
      for (; latest >= earliest; latest--) {
        if (!included[searchId].has(latest)) {
          break;
        }
      }
      for (let i = earliest; i <= latest; i++) {
        included[searchId].add(i);
      }
      // BUG FIX: previously `earliest !== latest`, which skipped fetching
      // when the shrunken window was exactly one order (leaving that order
      // marked as included but never fetched) and ran a pointless empty
      // query when the window was empty. `earliest <= latest` handles both.
      if (earliest <= latest) {
        const surrounding = await ctx.db
          .query("messages")
          .withIndex("threadId_status_tool_order_stepOrder", (q) => q
            .eq("threadId", m.threadId)
            .eq("status", "success")
            .eq("tool", false)
            .gte("order", earliest)
            .lte("order", latest))
          .collect();
        if (!ranges[searchId]) {
          ranges[searchId] = [];
        }
        ranges[searchId].push(...surrounding);
      }
    }
    // Append surrounding messages not already present, then sort by order.
    for (const r of Object.values(ranges).flat()) {
      if (!messages.some((m) => m._id === r._id)) {
        messages.push(publicMessage(r));
      }
    }
    return messages.sort((a, b) => a.order - b.order);
  },
});
// returns ranges of messages in order of text search relevance,
// excluding duplicates in later ranges.
/**
 * Full-text search over a thread's (or a user's) non-tool messages, bounded
 * by an optional cutoff message: only results strictly before the cutoff's
 * (order, stepOrder) are returned.
 */
export const textSearch = query({
  args: {
    threadId: v.optional(v.id("threads")),
    searchAllMessagesForUserId: v.optional(v.string()),
    text: v.string(),
    limit: v.number(),
    beforeMessageId: v.optional(v.id("messages")),
  },
  returns: v.array(vMessageDoc),
  handler: async (ctx, args) => {
    assert(args.searchAllMessagesForUserId || args.threadId, "Specify userId or threadId");
    const beforeMessage = args.beforeMessageId && (await ctx.db.get(args.beforeMessageId));
    const order = beforeMessage?.order;
    const messages = await ctx.db
      .query("messages")
      .withSearchIndex("text_search", (q) => args.searchAllMessagesForUserId
        ? q
          .search("text", args.text)
          .eq("userId", args.searchAllMessagesForUserId)
        : q.search("text", args.text).eq("threadId", args.threadId))
      // Just in case tool messages slip through
      .filter((q) => {
        const qq = q.eq(q.field("tool"), false);
        // BUG FIX: compare against undefined rather than truthiness so a
        // cutoff message at order 0 still constrains the results here
        // instead of wasting the take(limit) budget on messages the
        // post-filter below discards.
        if (order !== undefined) {
          return q.and(qq, q.lte(q.field("order"), order));
        }
        return qq;
      })
      .take(args.limit);
    // Drop anything at or after the cutoff's exact (order, stepOrder).
    return messages
      .filter((m) => !beforeMessage ||
        m.order < beforeMessage.order ||
        (m.order === beforeMessage.order &&
          m.stepOrder < beforeMessage.stepOrder))
      .map(publicMessage);
  },
});
//# sourceMappingURL=messages.js.map