@convex-dev/agent
Version:
An agent component for Convex.
238 lines (230 loc) • 6.83 kB
text/typescript
import type { ModelMessage } from "ai";
import type { PaginationOptions, PaginationResult } from "convex/server";
import type { MessageDoc } from "../validators.js";
import { validateVectorDimension } from "../component/vector/tables.js";
import {
vMessageWithMetadata,
type Message,
type MessageEmbeddings,
type MessageEmbeddingsWithDimension,
type MessageStatus,
type MessageWithMetadata,
} from "../validators.js";
import { serializeMessage } from "../mapping.js";
import { toUIMessages, type UIMessage } from "../UIMessages.js";
import type {
AgentComponent,
MutationCtx,
QueryCtx,
ActionCtx,
} from "./types.js";
import { parse } from "convex-helpers/validators";
/**
 * Fetch one page of messages belonging to a thread.
 *
 * The page is requested from the component's `listMessagesByThreadId` query
 * in descending order.
 * @param ctx A ctx object from a query, mutation, or action.
 * @param component The agent component, usually `components.agent`.
 * @param args.threadId The thread to list messages from.
 * @param args.paginationOpts Pagination options (e.g. via usePaginatedQuery).
 *   When `numItems` is 0, an empty, finished page is returned without
 *   running the query.
 * @param args.excludeToolMessages Whether to exclude tool messages.
 *   False by default.
 * @param args.statuses What statuses to include. All by default.
 * @returns The MessageDocs in a format compatible with usePaginatedQuery.
 */
export async function listMessages(
  ctx: QueryCtx | MutationCtx | ActionCtx,
  component: AgentComponent,
  {
    threadId,
    paginationOpts,
    excludeToolMessages,
    statuses,
  }: {
    threadId: string;
    paginationOpts: PaginationOptions;
    excludeToolMessages?: boolean;
    statuses?: MessageStatus[];
  },
): Promise<PaginationResult<MessageDoc>> {
  const wantsItems = paginationOpts.numItems !== 0;
  if (wantsItems) {
    return ctx.runQuery(component.messages.listMessagesByThreadId, {
      order: "desc",
      threadId,
      paginationOpts,
      excludeToolMessages,
      statuses,
    });
  }
  // Nothing was requested: skip the query and hand back the incoming cursor
  // (or an empty string when there is none) so the caller's position is kept.
  const emptyPage: PaginationResult<MessageDoc> = {
    page: [],
    isDone: true,
    continueCursor: paginationOpts.cursor ?? "",
  };
  return emptyPage;
}
/**
 * List messages from a thread, converted to UIMessages.
 *
 * Delegates to `listMessages` and passes the resulting page through
 * `toUIMessages`; the remaining pagination fields are returned unchanged.
 * @param ctx A ctx object from a query, mutation, or action.
 * @param component The agent component, usually `components.agent`.
 * @param args.threadId The thread to list messages from.
 * @param args.paginationOpts Pagination options (e.g. via usePaginatedQuery).
 * @returns A pagination result whose page contains UIMessages.
 */
export async function listUIMessages(
  ctx: QueryCtx | MutationCtx | ActionCtx,
  component: AgentComponent,
  args: {
    threadId: string;
    paginationOpts: PaginationOptions;
  },
): Promise<PaginationResult<UIMessage>> {
  const docs = await listMessages(ctx, component, args);
  const page = toUIMessages(docs.page);
  return { ...docs, page };
}
/**
* Arguments accepted by `saveMessages` for saving one or more messages to a
* thread.
*/
export type SaveMessagesArgs = {
/**
* The thread to save the messages to.
*/
threadId: string;
/**
* The user to associate with the messages, if any.
* `saveMessages` treats `null` the same as leaving it unset.
*/
userId?: string | null;
/**
* The message that these messages are in response to. They will be
* the same "order" as this message, at increasing stepOrder(s).
*/
promptMessageId?: string;
/**
* The messages to save.
*/
messages: (ModelMessage | Message)[];
/**
* Metadata to save with the messages. Each element corresponds to the
* message at the same index. Any `fileIds` given here are combined with
* the file IDs produced when the message is serialized.
*/
metadata?: Omit<MessageWithMetadata, "message">[];
/**
* If true, it will fail any pending steps.
* Defaults to false.
*/
failPendingSteps?: boolean;
/**
* The embeddings to save with the messages.
* `saveMessages` takes the dimension from the first non-null vector and
* omits the embeddings if it finds no non-empty vector.
*/
embeddings?: MessageEmbeddings;
/**
* A pending message ID to replace when adding messages.
*/
pendingMessageId?: string;
};
/**
 * Explicitly save messages associated with the thread (& user if provided).
 *
 * Each message is serialized, merged with the metadata at the same index,
 * and validated against `vMessageWithMetadata` before being handed to the
 * component's `addMessages` mutation.
 * @param ctx A ctx object from a mutation or action.
 * @param component The agent component, usually `components.agent`.
 * @param args The messages to save and what to associate them with.
 * @returns The saved message documents.
 */
export async function saveMessages(
  ctx: MutationCtx | ActionCtx,
  component: AgentComponent,
  args: SaveMessagesArgs & {
    /**
     * The agent name to associate with the messages.
     */
    agentName?: string;
  },
): Promise<{ messages: MessageDoc[] }> {
  // Attach a dimension to the provided embeddings, taken from the first
  // non-null vector. Without a usable vector no embeddings are sent at all.
  const resolveEmbeddings = (): MessageEmbeddingsWithDimension | undefined => {
    const provided = args.embeddings;
    if (!provided) return undefined;
    const dimension = provided.vectors.find((vector) => vector !== null)
      ?.length;
    if (!dimension) return undefined;
    validateVectorDimension(dimension);
    return { model: provided.model, dimension, vectors: provided.vectors };
  };
  const embeddings = resolveEmbeddings();

  const serialized = await Promise.all(
    args.messages.map(async (original, index) => {
      const { message, fileIds } = await serializeMessage(
        ctx,
        component,
        original,
      );
      const base = args.metadata?.[index];
      // Caller-supplied file IDs come first, then those found while serializing.
      const mergedFileIds = [...(base?.fileIds ?? []), ...(fileIds ?? [])];
      const fileIdsPatch =
        mergedFileIds.length > 0 ? { fileIds: mergedFileIds } : {};
      return parse(vMessageWithMetadata, {
        ...base,
        message,
        ...fileIdsPatch,
      });
    }),
  );

  const { messages } = await ctx.runMutation(component.messages.addMessages, {
    threadId: args.threadId,
    userId: args.userId ?? undefined,
    agentName: args.agentName,
    promptMessageId: args.promptMessageId,
    pendingMessageId: args.pendingMessageId,
    embeddings,
    messages: serialized,
    failPendingSteps: args.failPendingSteps ?? false,
  });
  return { messages };
}
/**
* Arguments accepted by `saveMessage` for saving a single message to a
* thread. Exactly one of `message` or `prompt` is provided.
*/
export type SaveMessageArgs = {
/**
* The thread to save the message to.
*/
threadId: string;
/**
* The user to associate with the message, if any.
* `null` is treated the same as leaving it unset.
*/
userId?: string | null;
/**
* The message that this message is in response to. It will be
* the same "order" as that message, at an increasing stepOrder.
*/
promptMessageId?: string;
/**
* Metadata to save with the message.
*/
metadata?: Omit<MessageWithMetadata, "message">;
/**
* The embedding to save with the message: the vector and the name of the
* model that produced it.
*/
embedding?: { vector: number[]; model: string };
/**
* A pending message ID to replace with this message.
*/
pendingMessageId?: string;
} & (
| {
prompt?: undefined;
/**
* The message to save.
*/
message: ModelMessage | Message;
}
| {
/**
* A text prompt to save instead of `message`. `saveMessage` stores it as
* a message with role "user" and the prompt as its content.
*/
prompt: string;
message?: undefined;
}
);
/**
 * Save a single message to the thread.
 *
 * Wraps `saveMessages` with a one-element message list. When `args.prompt`
 * is given it is saved as a user message; otherwise `args.message` is saved.
 * @param ctx A ctx object from a mutation or action.
 * @param component The agent component, usually `components.agent`.
 * @param args The message and what to associate it with (user / thread)
 * You can pass extra metadata alongside the message, e.g. associated fileIds.
 * @returns The saved message document, along with its ID as `messageId`.
 * @throws Error if saving did not return any message.
 */
export async function saveMessage(
  ctx: MutationCtx | ActionCtx,
  component: AgentComponent,
  args: SaveMessageArgs & {
    /**
     * The agent name to associate with the message.
     */
    agentName?: string;
  },
): Promise<{ messageId: MessageDoc["_id"]; message: MessageDoc }> {
  // saveMessages expects a list of vectors aligned with its message list,
  // so the single embedding becomes a one-element list.
  const { embedding } = args;
  const embeddings = embedding?.vector
    ? { model: embedding.model, vectors: [embedding.vector] }
    : undefined;
  const { messages } = await saveMessages(ctx, component, {
    threadId: args.threadId,
    userId: args.userId ?? undefined,
    agentName: args.agentName,
    promptMessageId: args.promptMessageId,
    pendingMessageId: args.pendingMessageId,
    messages:
      args.prompt !== undefined
        ? [{ role: "user", content: args.prompt }]
        : [args.message],
    metadata: args.metadata ? [args.metadata] : undefined,
    embeddings,
  });
  // Guard explicitly instead of asserting non-null, so an empty result fails
  // with a clear message rather than a TypeError on `._id`.
  const message = messages.at(-1);
  if (!message) {
    throw new Error(
      `saveMessage: no message was returned when saving to thread ${args.threadId}`,
    );
  }
  return { messageId: message._id, message };
}