@convex-dev/agent
An agent component for Convex.
import type {
AgentComponent,
ContextOptions,
RunActionCtx,
RunQueryCtx,
} from "./types.js";
import type { MessageDoc } from "../component/schema.js";
import type { EmbeddingModel, LanguageModel, ModelMessage } from "ai";
import { assert } from "convex-helpers";
import {
DEFAULT_MESSAGE_RANGE,
DEFAULT_RECENT_MESSAGES,
extractText,
} from "../shared.js";
import type { Message } from "../validators.js";
const DEFAULT_VECTOR_SCORE_THRESHOLD = 0.0;
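/**
* Callback used to embed the query text when vector search is enabled.
* Below is a rough sketch of one way to implement it with the AI SDK's
* `embed` helper; the provider and model name are assumptions, so substitute
* whichever embedding model your app uses.
* @example
* import { embed } from "ai";
* import { openai } from "@ai-sdk/openai";
*
* const getEmbedding: GetEmbedding = async (text) => {
*   const model = openai.embedding("text-embedding-3-small");
*   const { embedding } = await embed({ model, value: text });
*   return { embedding, textEmbeddingModel: model };
* };
*/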
export type GetEmbedding = (
text: string,
) => Promise<{
embedding: number[];
textEmbeddingModel: string | EmbeddingModel<string>;
}>;
/**
* Fetch the context messages for a thread.
* @param ctx Either a query, mutation, or action ctx. If it is not an
* action ctx, you can't do text or vector search.
* @param component The agent component to query.
* @param args The associated thread, user, and messages, plus context options
* controlling how many recent messages to fetch and how to search.
* @returns The context messages, sorted by order and stepOrder.
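* @example
* A rough sketch of calling this from a Convex action, assuming a
* `components.agent` binding from the app's generated API and the
* `getEmbedding` helper sketched above (both are assumptions, not part of
* this module):
*
* const contextMessages = await fetchContextMessages(ctx, components.agent, {
*   userId,
*   threadId,
*   messages: [{ role: "user", content: "What did we decide about search?" }],
*   contextOptions: {
*     recentMessages: 20,
*     searchOptions: { textSearch: true, vectorSearch: true, limit: 10 },
*   },
*   getEmbedding,
* });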
*/
export async function fetchContextMessages(
ctx: RunQueryCtx | RunActionCtx,
component: AgentComponent,
args: {
userId: string | undefined;
threadId: string | undefined;
messages: (ModelMessage | Message)[];
/**
* If provided, it will search for messages up to and including this message.
* Note: if this is far in the past, text and vector search results may be more
* limited, as it's post-filtering the results.
*/
upToAndIncludingMessageId?: string;
contextOptions: ContextOptions;
getEmbedding?: GetEmbedding;
},
): Promise<MessageDoc[]> {
assert(args.userId || args.threadId, "Specify userId or threadId");
const opts = args.contextOptions;
// Fetch the latest messages from the thread
let included: Set<string> | undefined;
const contextMessages: MessageDoc[] = [];
if (
args.threadId &&
(opts.recentMessages !== 0 || args.upToAndIncludingMessageId)
) {
const { page } = await ctx.runQuery(
component.messages.listMessagesByThreadId,
{
threadId: args.threadId,
excludeToolMessages: opts.excludeToolMessages,
paginationOpts: {
numItems: opts.recentMessages ?? DEFAULT_RECENT_MESSAGES,
cursor: null,
},
upToAndIncludingMessageId: args.upToAndIncludingMessageId,
order: "desc",
statuses: ["success"],
},
);
included = new Set(page.map((m) => m._id));
contextMessages.push(
// Reverse since we fetched in descending order
...page.reverse(),
);
}
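// If text or vector search is enabled, augment the recent messages with
// search results over older messages. This requires an action ctx.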
if (opts.searchOptions?.textSearch || opts.searchOptions?.vectorSearch) {
const targetMessage = contextMessages.find(
(m) => m._id === args.upToAndIncludingMessageId,
)?.message;
const messagesToSearch = targetMessage ? [targetMessage] : args.messages;
if (!("runAction" in ctx)) {
throw new Error("searchUserMessages only works in an action");
}
const lastMessage = messagesToSearch.at(-1)!;
assert(lastMessage, "No messages to search");
const text = extractText(lastMessage);
assert(text, `No text to search in message ${JSON.stringify(lastMessage)}`);
assert(
!args.contextOptions?.searchOptions?.vectorSearch || "runAction" in ctx,
"You must do vector search from an action",
);
if (opts.searchOptions?.vectorSearch && !args.getEmbedding) {
throw new Error(
"You must provide an embedding and embeddingModel to use vector search",
);
}
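// Embed the query text up front if vector search is enabled, so the search
// action can do both text and vector search in one call.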
const embeddingFields =
opts.searchOptions?.vectorSearch && text
? await args.getEmbedding?.(text)
: undefined;
const searchMessages = await ctx.runAction(
component.messages.searchMessages,
{
searchAllMessagesForUserId: opts.searchOtherThreads
? (args.userId ??
(args.threadId &&
(
await ctx.runQuery(component.threads.getThread, {
threadId: args.threadId,
})
)?.userId))
: undefined,
threadId: args.threadId,
beforeMessageId: args.upToAndIncludingMessageId,
limit: opts.searchOptions?.limit ?? 10,
messageRange: {
...DEFAULT_MESSAGE_RANGE,
...opts.searchOptions?.messageRange,
},
text,
vectorScoreThreshold:
opts.searchOptions?.vectorScoreThreshold ??
DEFAULT_VECTOR_SCORE_THRESHOLD,
embedding: embeddingFields?.embedding,
embeddingModel: embeddingFields?.textEmbeddingModel
? getModelName(embeddingFields.textEmbeddingModel)
: undefined,
},
);
// TODO: track what messages we used for context
contextMessages.unshift(
...searchMessages.filter((m) => !included?.has(m._id)),
);
}
// Ensure we don't include tool messages without a corresponding tool call
return filterOutOrphanedToolMessages(
contextMessages.sort((a, b) =>
// Sort the raw MessageDocs by order and stepOrder
a.order === b.order ? a.stepOrder - b.stepOrder : a.order - b.order,
),
);
}
/**
* Filter out tool result messages whose tool call isn't present in a
* preceding assistant message, so we never include a tool response without
* its corresponding tool call.
* @param docs The messages to filter.
* @returns The filtered messages.
*/
export function filterOutOrphanedToolMessages(docs: MessageDoc[]) {
const toolCallIds = new Set<string>();
const result: MessageDoc[] = [];
for (const doc of docs) {
if (
doc.message?.role === "assistant" &&
Array.isArray(doc.message.content)
) {
for (const content of doc.message.content) {
if (content.type === "tool-call") {
toolCallIds.add(content.toolCallId);
}
}
result.push(doc);
} else if (doc.message?.role === "tool") {
if (doc.message.content.every((c) => toolCallIds.has(c.toolCallId))) {
result.push(doc);
} else {
console.debug("Filtering out orphaned tool message", doc);
}
} else {
result.push(doc);
}
}
return result;
}
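/**
* Get the bare model name from a model object or a "provider/model" string,
* e.g. "openai/text-embedding-3-small" -> "text-embedding-3-small".
*/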
export function getModelName(
embeddingModel: string | EmbeddingModel<string> | LanguageModel,
): string {
if (typeof embeddingModel === "string") {
if (embeddingModel.includes("/")) {
return embeddingModel.split("/").slice(1).join("/");
}
return embeddingModel;
}
return embeddingModel.modelId;
}
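/**
* Get the provider name from a model object or a "provider/model" string,
* e.g. "openai/text-embedding-3-small" -> "openai".
*/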
export function getProviderName(
embeddingModel: string | EmbeddingModel<string> | LanguageModel,
): string {
if (typeof embeddingModel === "string") {
return embeddingModel.split("/").at(0)!;
}
return embeddingModel.provider;
}