@convex-dev/agent
An agent component for Convex.
743 lines (722 loc) • 23 kB
text/typescript
import { assert, omit, pick } from "convex-helpers";
import { mergedStream, stream } from "convex-helpers/server/stream";
import {
paginationOptsValidator,
type WithoutSystemFields,
} from "convex/server";
import type { ObjectType } from "convex/values";
import {
DEFAULT_MESSAGE_RANGE,
DEFAULT_RECENT_MESSAGES,
extractText,
isTool,
} from "../shared.js";
import {
vMessageEmbeddingsWithDimension,
vMessageStatus,
vMessageWithMetadataInternal,
vPaginationResult,
} from "../validators.js";
import { api, internal } from "./_generated/api.js";
import type { Doc, Id } from "./_generated/dataModel.js";
import {
action,
internalQuery,
mutation,
type MutationCtx,
query,
type QueryCtx,
} from "./_generated/server.js";
import type { MessageDoc } from "./schema.js";
import { schema, v, vMessageDoc } from "./schema.js";
import { insertVector, searchVectors } from "./vector/index.js";
import {
type VectorDimension,
VectorDimensions,
type VectorTableId,
vVectorId,
} from "./vector/tables.js";
import { changeRefcount } from "./files.js";
import { getStreamingMessagesWithMetadata } from "./streams.js";
import { partial } from "convex-helpers/validators";
function publicMessage(message: Doc<"messages">): MessageDoc {
return omit(message, ["parentMessageId", "stepId", "files"]);
}
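/**
 * Deletes a message, its embedding (if any), and releases its references to
 * any associated files via `changeRefcount`.
 */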
export async function deleteMessage(
ctx: MutationCtx,
messageDoc: Doc<"messages">,
) {
await ctx.db.delete(messageDoc._id);
if (messageDoc.embeddingId) {
await ctx.db.delete(messageDoc.embeddingId);
}
if (messageDoc.fileIds) {
await changeRefcount(ctx, messageDoc.fileIds, []);
}
}
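/**
 * Deletes the given messages (and their embeddings / file references),
 * returning the ids that were found and deleted.
 */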
export const deleteByIds = mutation({
args: { messageIds: v.array(v.id("messages")) },
returns: v.array(v.id("messages")),
handler: async (ctx, args) => {
const deletedMessageIds = await Promise.all(
args.messageIds.map(async (id) => {
const message = await ctx.db.get(id);
if (message) {
await deleteMessage(ctx, message);
return id;
}
return null;
}),
);
return deletedMessageIds.filter((id) => id !== null);
},
});
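/** Every possible message status, derived from the schema's status validator. */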
export const messageStatuses = vMessageDoc.fields.status.members.map(
(m) => m.value,
);
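/**
 * Deletes up to 64 messages whose (order, stepOrder) falls in the half-open
 * range [(startOrder, startStepOrder), (endOrder, endStepOrder)), reporting
 * how far it got so callers can loop until `isDone` is true.
 *
 * @example
 * // A rough sketch of clearing a thread from the host app, assuming this
 * // component is mounted as `components.agent`:
 * let startOrder = 0;
 * while (true) {
 *   const { isDone, lastOrder } = await ctx.runMutation(
 *     components.agent.messages.deleteByOrder,
 *     { threadId, startOrder, endOrder: Number.MAX_SAFE_INTEGER },
 *   );
 *   if (isDone) break;
 *   startOrder = lastOrder ?? startOrder;
 * }
 */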
export const deleteByOrder = mutation({
args: {
threadId: v.id("threads"),
startOrder: v.number(),
startStepOrder: v.optional(v.number()),
endOrder: v.number(),
endStepOrder: v.optional(v.number()),
},
returns: v.object({
isDone: v.boolean(),
lastOrder: v.optional(v.number()),
lastStepOrder: v.optional(v.number()),
}),
handler: async (ctx, args) => {
const messages = await orderedMessagesStream(
ctx,
args.threadId,
"asc",
args.startOrder,
)
.narrow({
lowerBound: args.startStepOrder
? [args.startOrder, args.startStepOrder]
: [args.startOrder],
lowerBoundInclusive: true,
upperBound: args.endStepOrder
? [args.endOrder, args.endStepOrder]
: [args.endOrder],
upperBoundInclusive: false,
})
.take(64);
await Promise.all(messages.map((m) => deleteMessage(ctx, m)));
return {
isDone: messages.length < 64,
lastOrder: messages.at(-1)?.order,
lastStepOrder: messages.at(-1)?.stepOrder,
};
},
});
const addMessagesArgs = {
userId: v.optional(v.string()),
threadId: v.id("threads"),
promptMessageId: v.optional(v.id("messages")),
agentName: v.optional(v.string()),
messages: v.array(vMessageWithMetadataInternal),
embeddings: v.optional(vMessageEmbeddingsWithDimension),
failPendingSteps: v.optional(v.boolean()),
// A pending message to update. If the pending message failed, abort.
pendingMessageId: v.optional(v.id("messages")),
};
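/**
 * Adds messages to a thread: assigns `order`/`stepOrder`, optionally inserts
 * embeddings, fails other pending steps, and/or replaces a pending
 * placeholder message (see `addMessagesHandler` for the ordering rules).
 */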
export const addMessages = mutation({
args: addMessagesArgs,
handler: addMessagesHandler,
returns: v.object({ messages: v.array(vMessageDoc) }),
});
async function addMessagesHandler(
ctx: MutationCtx,
args: ObjectType<typeof addMessagesArgs>,
) {
let userId = args.userId;
const threadId = args.threadId;
if (!userId && args.threadId) {
const thread = await ctx.db.get(args.threadId);
assert(thread, `Thread ${args.threadId} not found`);
userId = thread.userId;
}
const {
embeddings,
failPendingSteps,
messages,
promptMessageId,
pendingMessageId,
...rest
} = args;
const promptMessage = promptMessageId && (await ctx.db.get(promptMessageId));
if (failPendingSteps) {
assert(args.threadId, "threadId is required to fail pending steps");
const pendingMessages = await ctx.db
.query("messages")
.withIndex("threadId_status_tool_order_stepOrder", (q) =>
q.eq("threadId", threadId).eq("status", "pending"),
)
.order("desc")
.take(100);
await Promise.all(
pendingMessages
.filter((m) => !promptMessage || m.order === promptMessage.order)
.filter((m) => !pendingMessageId || m._id !== pendingMessageId)
.map(async (m) => {
if (m.embeddingId) {
await ctx.db.delete(m.embeddingId);
}
await ctx.db.patch(m._id, {
status: "failed",
error: "Restarting",
embeddingId: undefined,
});
}),
);
}
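// Seed order/stepOrder: when responding to a prompt message, append after the
// latest message at that prompt's order; otherwise start after the thread's
// latest message (or at -1 in an empty thread, so the first insert lands at 0).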
let order, stepOrder;
let fail = false;
let error: string | undefined;
if (promptMessageId) {
assert(promptMessage, `Parent message ${promptMessageId} not found`);
if (promptMessage.status === "failed") {
fail = true;
error = promptMessage.error ?? error ?? "The prompt message failed";
}
order = promptMessage.order;
// Defend against messages already existing at this prompt's order.
const maxMessage = await getMaxMessage(ctx, threadId, order);
stepOrder = maxMessage?.stepOrder ?? promptMessage.stepOrder;
} else {
const maxMessage = await getMaxMessage(ctx, threadId);
order = maxMessage?.order ?? -1;
stepOrder = maxMessage?.stepOrder ?? -1;
}
const toReturn: Doc<"messages">[] = [];
if (embeddings) {
assert(
embeddings.vectors.length === messages.length,
"embeddings.vectors.length must match messages.length",
);
}
for (let i = 0; i < messages.length; i++) {
const message = messages[i];
let embeddingId: VectorTableId | undefined;
if (
embeddings &&
embeddings.vectors[i] &&
!fail &&
message.status !== "failed"
) {
embeddingId = await insertVector(ctx, embeddings.dimension, {
vector: embeddings.vectors[i]!,
model: embeddings.model,
table: "messages",
userId,
threadId,
});
}
const messageDoc = {
...rest,
...message,
embeddingId,
parentMessageId: promptMessageId,
userId,
tool: isTool(message.message),
text: extractText(message.message),
status: fail ? "failed" : (message.status ?? "success"),
error: fail ? error : message.error,
} satisfies Omit<
WithoutSystemFields<Doc<"messages">>,
"order" | "stepOrder"
>;
// If there is a pending message, we replace it with the first message;
// subsequent messages follow the regular order/stepOrder advancement.
if (i === 0 && pendingMessageId) {
const pendingMessage = await ctx.db.get(pendingMessageId);
assert(pendingMessage, `Pending msg ${pendingMessageId} not found`);
if (pendingMessage.status === "failed") {
fail = true;
error =
`Trying to update a message that failed: ${pendingMessageId}, ` +
`error: ${pendingMessage.error ?? error}`;
messageDoc.status = "failed";
messageDoc.error = error;
}
if (message.fileIds) {
await changeRefcount(
ctx,
pendingMessage.fileIds ?? [],
message.fileIds,
);
}
await ctx.db.replace(pendingMessage._id, {
...messageDoc,
order: pendingMessage.order,
stepOrder: pendingMessage.stepOrder,
});
// Return the updated document rather than the stale pending one.
toReturn.push((await ctx.db.get(pendingMessage._id))!);
continue;
}
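// Advance the position: a user message starts a new order (with stepOrder 0),
// while any other role continues the current order at the next stepOrder.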
if (message.message.role === "user") {
if (promptMessage && promptMessage.order === order) {
// See if there's a message at a later order than the prompt message's order.
const maxMessage = await getMaxMessage(ctx, threadId);
order = (maxMessage?.order ?? order) + 1;
} else {
order++;
}
stepOrder = 0;
} else {
if (order < 0) {
order = 0;
}
stepOrder++;
}
const messageId = await ctx.db.insert("messages", {
...messageDoc,
order,
stepOrder,
});
if (message.fileIds) {
await changeRefcount(ctx, [], message.fileIds);
}
// TODO: delete the associated stream data for the order/stepOrder
toReturn.push((await ctx.db.get(messageId))!);
}
return { messages: toReturn.map(publicMessage) };
}
// exported for tests
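/** Returns the latest message in the thread by (order, stepOrder), optionally restricted to a single order. */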
export async function getMaxMessage(
ctx: QueryCtx,
threadId: Id<"threads">,
order?: number,
) {
return orderedMessagesStream(ctx, threadId, "desc", order).first();
}
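/**
 * Streams a thread's messages sorted by (order, stepOrder). Since the index is
 * prefixed by status and tool, one stream per (status, tool) combination is
 * merged to produce a single globally-ordered stream.
 */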
function orderedMessagesStream(
ctx: QueryCtx,
threadId: Id<"threads">,
sortOrder: "asc" | "desc",
order?: number,
) {
return mergedStream(
[true, false].flatMap((tool) =>
messageStatuses.map((status) =>
stream(ctx.db, schema)
.query("messages")
.withIndex("threadId_status_tool_order_stepOrder", (q) => {
const qq = q
.eq("threadId", threadId)
.eq("status", status)
.eq("tool", tool);
if (order !== undefined) {
return qq.eq("order", order);
}
return qq;
})
.order(sortOrder),
),
),
["order", "stepOrder"],
);
}
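/**
 * Marks a pending message as succeeded or failed. If the message has no
 * content yet but streamed data exists for it, the streamed messages are saved
 * in its place instead. Failing a message also deletes its embedding.
 */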
export const finalizeMessage = mutation({
args: {
messageId: v.id("messages"),
result: v.union(
v.object({ status: v.literal("success") }),
v.object({ status: v.literal("failed"), error: v.string() }),
),
},
returns: v.null(),
handler: async (ctx, { messageId, result }) => {
const message = await ctx.db.get(messageId);
assert(message, `Message ${messageId} not found`);
if (message.status !== "pending") {
console.log(
"Trying to finalize a message that's already",
message.status,
);
return;
}
// See if we can add any in-progress data
if (message.message === undefined) {
const messages = await getStreamingMessagesWithMetadata(
ctx,
message,
result,
);
if (messages.length > 0) {
await addMessagesHandler(ctx, {
messages,
threadId: message.threadId,
agentName: message.agentName,
failPendingSteps: false,
pendingMessageId: messageId,
userId: message.userId,
embeddings: undefined,
});
return;
}
}
if (result.status === "failed") {
if (message.embeddingId) {
await ctx.db.delete(message.embeddingId);
}
await ctx.db.patch(messageId, {
status: "failed",
error: result.error,
embeddingId: undefined,
});
} else {
await ctx.db.patch(messageId, { status: "success" });
}
},
});
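/**
 * Patches an allowed subset of a message's fields. Updating `message` also
 * recomputes the derived `tool` and `text` fields; updating `fileIds` adjusts
 * file refcounts; marking the message failed deletes its embedding.
 */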
export const updateMessage = mutation({
args: {
messageId: v.id("messages"),
patch: v.object(
partial(
pick(schema.tables.messages.validator.fields, [
"message",
"fileIds",
"status",
"error",
"model",
"provider",
"providerOptions",
"finishReason",
]),
),
),
},
returns: vMessageDoc,
handler: async (ctx, args) => {
const message = await ctx.db.get(args.messageId);
assert(message, `Message ${args.messageId} not found`);
if (args.patch.fileIds) {
await changeRefcount(ctx, message.fileIds ?? [], args.patch.fileIds);
}
const patch: Partial<Doc<"messages">> = { ...args.patch };
if (args.patch.message !== undefined) {
patch.message = args.patch.message;
patch.tool = isTool(args.patch.message);
patch.text = extractText(args.patch.message);
}
if (args.patch.status === "failed") {
if (message.embeddingId) {
await ctx.db.delete(message.embeddingId);
}
patch.embeddingId = undefined;
}
await ctx.db.patch(args.messageId, patch);
return publicMessage((await ctx.db.get(args.messageId))!);
},
});
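/**
 * Paginates a thread's messages in (order, stepOrder) order, optionally
 * excluding tool messages, filtering by status, and stopping after everything
 * at the order of `upToAndIncludingMessageId`.
 *
 * @example
 * // A rough sketch from a host-app query, assuming this component is mounted
 * // as `components.agent`:
 * const result = await ctx.runQuery(
 *   components.agent.messages.listMessagesByThreadId,
 *   { threadId, order: "desc", paginationOpts: { numItems: 20, cursor: null } },
 * );
 */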
export const listMessagesByThreadId = query({
args: {
threadId: v.id("threads"),
excludeToolMessages: v.optional(v.boolean()),
/** What order to sort the messages in. To get the latest, use "desc". */
order: v.union(v.literal("asc"), v.literal("desc")),
paginationOpts: v.optional(paginationOptsValidator),
statuses: v.optional(v.array(vMessageStatus)),
upToAndIncludingMessageId: v.optional(v.id("messages")),
},
handler: async (ctx, args) => {
const statuses =
args.statuses ?? vMessageStatus.members.map((m) => m.value);
const last =
args.upToAndIncludingMessageId &&
(await ctx.db.get(args.upToAndIncludingMessageId));
assert(
!last || last.threadId === args.threadId,
"upToAndIncludingMessageId must be a message in the thread",
);
const toolOptions = args.excludeToolMessages ? [false] : [true, false];
const order = args.order ?? "desc";
const streams = toolOptions.flatMap((tool) =>
statuses.map((status) =>
stream(ctx.db, schema)
.query("messages")
.withIndex("threadId_status_tool_order_stepOrder", (q) => {
const qq = q
.eq("threadId", args.threadId)
.eq("status", status)
.eq("tool", tool);
if (last) {
return qq.lte("order", last.order);
}
return qq;
})
.order(order)
.filterWith(
// We allow all messages on the same order.
async (m) =>
!last || m.order < last.order || m.order === last.order,
),
),
);
const messages = await mergedStream(streams, [
"order",
"stepOrder",
]).paginate(
args.paginationOpts ?? {
numItems: DEFAULT_RECENT_MESSAGES,
cursor: null,
},
);
return { ...messages, page: messages.page.map(publicMessage) };
},
returns: vPaginationResult(vMessageDoc),
});
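/** Fetches messages by id, returning null for any id that no longer exists. */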
export const getMessagesByIds = query({
args: { messageIds: v.array(v.id("messages")) },
handler: async (ctx, args) => {
return (await Promise.all(args.messageIds.map((id) => ctx.db.get(id)))).map(
(m) => (m ? publicMessage(m) : null),
);
},
returns: v.array(v.union(v.null(), vMessageDoc)),
});
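/**
 * Hybrid search over messages, scoped to a thread or to all of a user's
 * messages. Runs text search and/or vector search depending on which of
 * `text` and `embedding` are provided, and combines them with reciprocal rank
 * fusion. When vector search runs, hits are expanded with surrounding context
 * messages per `messageRange`.
 *
 * @example
 * // A rough sketch from a host-app action, assuming this component is mounted
 * // as `components.agent` and `embed` is your own embedding helper:
 * const related = await ctx.runAction(
 *   components.agent.messages.searchMessages,
 *   { threadId, text: "budget", embedding: await embed("budget"), limit: 10 },
 * );
 */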
export const searchMessages = action({
args: {
threadId: v.optional(v.id("threads")),
searchAllMessagesForUserId: v.optional(v.string()),
beforeMessageId: v.optional(v.id("messages")),
embedding: v.optional(v.array(v.number())),
embeddingModel: v.optional(v.string()),
text: v.optional(v.string()),
limit: v.number(),
vectorScoreThreshold: v.optional(v.number()),
messageRange: v.optional(
v.object({ before: v.number(), after: v.number() }),
),
},
returns: v.array(vMessageDoc),
handler: async (ctx, args): Promise<MessageDoc[]> => {
assert(
args.searchAllMessagesForUserId || args.threadId,
"Specify userId or threadId",
);
const limit = args.limit;
let textSearchMessages: MessageDoc[] | undefined;
if (args.text) {
textSearchMessages = await ctx.runQuery(api.messages.textSearch, {
searchAllMessagesForUserId: args.searchAllMessagesForUserId,
threadId: args.threadId,
text: args.text,
limit,
beforeMessageId: args.beforeMessageId,
});
}
if (args.embedding) {
const dimension = args.embedding.length as VectorDimension;
if (!VectorDimensions.includes(dimension)) {
throw new Error(`Unsupported embedding dimension: ${dimension}`);
}
const vectors = (
await searchVectors(ctx, args.embedding, {
dimension,
model: args.embeddingModel ?? "unknown",
table: "messages",
searchAllMessagesForUserId: args.searchAllMessagesForUserId,
threadId: args.threadId,
limit,
})
).filter((v) => v._score > (args.vectorScoreThreshold ?? 0));
// Reciprocal rank fusion
const k = 10;
const textEmbeddingIds = textSearchMessages?.map((m) => m.embeddingId);
const vectorScores = vectors
.map((v, i) => {
// This result's rank in the text search results, or Infinity when it
// didn't match the text search (Array.indexOf returns -1 for misses,
// which would otherwise outscore a rank-0 text match).
const textRank = textEmbeddingIds?.indexOf(v._id) ?? -1;
return {
id: v._id,
score:
1 / (i + k) +
1 / ((textRank === -1 ? Infinity : textRank) + k),
};
})
.sort((a, b) => b.score - a.score);
const embeddingIds = vectorScores.slice(0, limit).map((v) => v.id);
const messages: MessageDoc[] = await ctx.runQuery(
internal.messages._fetchSearchMessages,
{
searchAllMessagesForUserId: args.searchAllMessagesForUserId,
threadId: args.threadId,
embeddingIds,
textSearchMessages: textSearchMessages?.filter(
(m) => !embeddingIds.includes(m.embeddingId! as VectorTableId),
),
messageRange: args.messageRange ?? DEFAULT_MESSAGE_RANGE,
beforeMessageId: args.beforeMessageId,
limit,
},
);
return messages;
}
return textSearchMessages?.flat() ?? [];
},
});
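/**
 * Resolves the fused embedding ids back to their messages, merges in any
 * text-only results, truncates to `limit`, and (when searching within a
 * thread) expands each hit with the surrounding non-tool messages per
 * `messageRange`, avoiding overlapping ranges.
 */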
export const _fetchSearchMessages = internalQuery({
args: {
threadId: v.optional(v.id("threads")),
embeddingIds: v.array(vVectorId),
searchAllMessagesForUserId: v.optional(v.string()),
textSearchMessages: v.optional(v.array(vMessageDoc)),
messageRange: v.object({ before: v.number(), after: v.number() }),
beforeMessageId: v.optional(v.id("messages")),
limit: v.number(),
},
returns: v.array(vMessageDoc),
handler: async (ctx, args): Promise<MessageDoc[]> => {
const beforeMessage =
args.beforeMessageId && (await ctx.db.get(args.beforeMessageId));
const { searchAllMessagesForUserId, threadId } = args;
assert(
searchAllMessagesForUserId || threadId,
"Specify searchAllMessagesForUserId or threadId to search",
);
let messages: MessageDoc[] = (
await Promise.all(
args.embeddingIds.map((embeddingId) =>
ctx.db
.query("messages")
.withIndex("embeddingId_threadId", (q) =>
searchAllMessagesForUserId
? q.eq("embeddingId", embeddingId)
: q.eq("embeddingId", embeddingId).eq("threadId", threadId!),
)
.filter((q) =>
q.and(
q.eq(q.field("status"), "success"),
searchAllMessagesForUserId
? q.eq(q.field("userId"), searchAllMessagesForUserId)
: q.eq(q.field("threadId"), threadId),
),
)
.first(),
),
)
)
.filter(
(m): m is Doc<"messages"> =>
m !== undefined &&
m !== null &&
!m.tool &&
(!beforeMessage ||
m.order < beforeMessage.order ||
(m.order === beforeMessage.order &&
m.stepOrder < beforeMessage.stepOrder)),
)
.map(publicMessage);
messages.push(...(args.textSearchMessages ?? []));
// TODO: prioritize more recent messages
messages.sort((a, b) => a.order! - b.order!);
messages = messages.slice(0, args.limit);
// Fetch the surrounding messages
if (!threadId) {
return messages.sort((a, b) => a.order - b.order);
}
const included: Record<string, Set<number>> = {};
for (const m of messages) {
const searchId = m.threadId ?? m.userId!;
if (!included[searchId]) {
included[searchId] = new Set();
}
included[searchId].add(m.order!);
}
const ranges: Record<string, Doc<"messages">[]> = {};
const { before, after } = args.messageRange;
for (const m of messages) {
const searchId = m.threadId ?? m.userId!;
const order = m.order!;
let earliest = order - before;
let latest = order + after;
for (; earliest <= latest; earliest++) {
if (!included[searchId].has(earliest)) {
break;
}
}
for (; latest >= earliest; latest--) {
if (!included[searchId].has(latest)) {
break;
}
}
for (let i = earliest; i <= latest; i++) {
included[searchId].add(i);
}
if (earliest !== latest) {
const surrounding = await ctx.db
.query("messages")
.withIndex("threadId_status_tool_order_stepOrder", (q) =>
q
.eq("threadId", m.threadId as Id<"threads">)
.eq("status", "success")
.eq("tool", false)
.gte("order", earliest)
.lte("order", latest),
)
.collect();
if (!ranges[searchId]) {
ranges[searchId] = [];
}
ranges[searchId].push(...surrounding);
}
}
for (const r of Object.values(ranges).flat()) {
if (!messages.some((m) => m._id === r._id)) {
messages.push(publicMessage(r));
}
}
return messages.sort((a, b) => a.order - b.order);
},
});
// Returns messages matching the text search, in order of relevance,
// excluding tool messages and anything at or after `beforeMessageId`.
export const textSearch = query({
args: {
threadId: v.optional(v.id("threads")),
searchAllMessagesForUserId: v.optional(v.string()),
text: v.string(),
limit: v.number(),
beforeMessageId: v.optional(v.id("messages")),
},
handler: async (ctx, args) => {
assert(
args.searchAllMessagesForUserId || args.threadId,
"Specify userId or threadId",
);
const beforeMessage =
args.beforeMessageId && (await ctx.db.get(args.beforeMessageId));
const order = beforeMessage?.order;
const messages = await ctx.db
.query("messages")
.withSearchIndex("text_search", (q) =>
args.searchAllMessagesForUserId
? q
.search("text", args.text)
.eq("userId", args.searchAllMessagesForUserId)
: q.search("text", args.text).eq("threadId", args.threadId!),
)
// Just in case tool messages slip through
.filter((q) => {
const qq = q.eq(q.field("tool"), false);
if (order !== undefined) {
return q.and(qq, q.lte(q.field("order"), order));
}
return qq;
})
.take(args.limit);
return messages
.filter(
(m) =>
!beforeMessage ||
m.order < beforeMessage.order ||
(m.order === beforeMessage.order &&
m.stepOrder < beforeMessage.stepOrder),
)
.map(publicMessage);
},
returns: v.array(vMessageDoc),
});