@convex-dev/agent
Version:
An agent component for Convex.
1,115 lines (1,114 loc) • 52.6 kB
JavaScript
import { embedMany, generateObject, generateText, stepCountIs, streamObject, streamText, } from "ai";
import { assert, omit, pick } from "convex-helpers";
import { internalActionGeneric, internalMutationGeneric, } from "convex/server";
import { convexToJson, v } from "convex/values";
import { validateVectorDimension, } from "../component/vector/tables.js";
import { deserializeMessage, serializeMessage, serializeNewMessagesInStep, serializeObjectResult, } from "../mapping.js";
import { extractText, isTool } from "../shared.js";
import { vMessageEmbeddings, vMessageWithMetadata, vSafeObjectArgs, vTextArgs, } from "../validators.js";
import { createTool, wrapTools } from "./createTool.js";
import { listMessages, saveMessages, } from "./messages.js";
import { fetchContextMessages, getModelName, getProviderName, } from "./search.js";
import { DeltaStreamer, mergeTransforms, syncStreams, } from "./streaming.js";
import { createThread, getThreadMetadata } from "./threads.js";
import { inlineMessagesFiles } from "./files.js";
export { stepCountIs } from "ai";
export { vMessageDoc, vThreadDoc } from "../component/schema.js";
export { deserializeMessage, serializeDataOrUrl, serializeMessage, } from "../mapping.js";
// NOTE: these are also exported via @convex-dev/agent/validators
// a future version may put them all here or move these over there
export { vAssistantMessage, vContextOptions, vMessage, vPaginationResult, vProviderMetadata, vStorageOptions, vStreamArgs, vSystemMessage, vToolMessage, vUsage, vUserMessage, } from "../validators.js";
export { definePlaygroundAPI, } from "./definePlaygroundAPI.js";
export { getFile, storeFile } from "./files.js";
export { listMessages, saveMessage, saveMessages, } from "./messages.js";
export { fetchContextMessages, filterOutOrphanedToolMessages, } from "./search.js";
export { abortStream, listStreams, syncStreams } from "./streaming.js";
export { createThread, getThreadMetadata, updateThreadMetadata, searchThreadTitles, } from "./threads.js";
export { createTool, extractText, isTool };
export class Agent {
component;
options;
constructor(component, options) {
this.component = component;
this.options = options;
}
async createThread(ctx, args) {
const threadId = await createThread(ctx, this.component, args);
if (!("runAction" in ctx) || "workflowId" in ctx) {
return { threadId };
}
const { thread } = await this.continueThread(ctx, {
threadId,
userId: args?.userId,
});
return { threadId, thread };
}
/**
* Continues a thread using this agent. Note: threads can be continued
* by different agents. This is a convenience around calling the various
* generate and stream functions with explicit userId and threadId parameters.
* @param ctx The ctx object passed to the action handler
* @param { threadId, userId }: the thread and user to associate the messages with.
* @returns Functions bound to the userId and threadId on a `{thread}` object.
*/
async continueThread(ctx, args) {
return {
thread: {
threadId: args.threadId,
getMetadata: this.getThreadMetadata.bind(this, ctx, {
threadId: args.threadId,
}),
updateMetadata: (patch) => ctx.runMutation(this.component.threads.updateThread, {
threadId: args.threadId,
patch,
}),
generateText: this.generateText.bind(this, ctx, args),
streamText: this.streamText.bind(this, ctx, args),
generateObject: this.generateObject.bind(this, ctx, args),
streamObject: this.streamObject.bind(this, ctx, args),
},
};
}
async start(ctx,
/**
* These are the arguments you'll pass to the LLM call such as
* `generateText` or `streamText`. This function will look up the context
* and provide functions to save the steps, abort the generation, and more.
* The type of the arguments returned infers from the type of the arguments
* you pass here.
*/
args, options) {
const { threadId, ...opts } = { ...this.options, ...options };
const context = await this._saveMessagesAndFetchContext(ctx, args, {
userId: options?.userId,
threadId: options?.threadId,
...opts,
});
let pendingMessageId = context.pendingMessageId;
// TODO: extract pending message if one exists
const { args: aiArgs, promptMessageId, order, stepOrder, userId } = context;
const messages = context.savedMessages ?? [];
if (pendingMessageId) {
if (!aiArgs._internal?.generateId) {
aiArgs._internal = {
...aiArgs._internal,
generateId: () => pendingMessageId ?? crypto.randomUUID(),
};
}
}
const toolCtx = {
...ctx,
userId,
threadId,
promptMessageId,
agent: this,
};
const tools = wrapTools(toolCtx, args.tools ?? this.options.tools);
const saveOutput = opts.storageOptions?.saveMessages !== "none";
const fail = async (reason) => {
if (threadId && promptMessageId) {
console.error("RollbackMessage", promptMessageId, reason);
}
if (pendingMessageId) {
await ctx.runMutation(this.component.messages.finalizeMessage, {
messageId: pendingMessageId,
result: { status: "failed", error: reason },
});
}
};
let activeModel = aiArgs.model;
if (aiArgs.abortSignal) {
const abortSignal = aiArgs.abortSignal;
aiArgs.abortSignal.addEventListener("abort", async () => {
await fail(abortSignal.reason ?? "Aborted");
}, { once: true });
}
return {
args: {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
stopWhen: args.stopWhen ?? this.options.stopWhen,
...aiArgs,
tools,
// abortSignal: abortController.signal,
},
order: order ?? 0,
stepOrder: stepOrder ?? 0,
userId,
promptMessageId,
getSavedMessages: () => messages,
updateModel: (model) => {
if (model) {
activeModel = model;
}
},
fail,
save: async (toSave, createPendingMessage) => {
if (threadId && promptMessageId && saveOutput) {
const metadata = {
// TODO: get up to date one when user selects mid-generation
model: getModelName(activeModel),
provider: getProviderName(activeModel),
};
const serialized = "object" in toSave
? await serializeObjectResult(ctx, this.component, toSave.object, metadata)
: await serializeNewMessagesInStep(ctx, this.component, toSave.step, metadata);
const embeddings = await this.generateEmbeddings(ctx, { userId, threadId }, serialized.messages.map((m) => m.message));
if (createPendingMessage) {
serialized.messages.push({
message: { role: "assistant", content: [] },
status: "pending",
});
embeddings?.vectors.push(null);
}
const saved = await ctx.runMutation(this.component.messages.addMessages, {
userId,
threadId,
agentName: this.options.name,
promptMessageId,
pendingMessageId,
messages: serialized.messages,
embeddings,
failPendingSteps: false,
});
const lastMessage = saved.messages.at(-1);
if (createPendingMessage) {
if (lastMessage.status === "failed") {
pendingMessageId = undefined;
messages.push(...saved.messages);
await fail(lastMessage.error ??
"Aborting - the pending message was marked as failed");
}
else {
pendingMessageId = lastMessage._id;
messages.push(...saved.messages.slice(0, -1));
}
}
else {
pendingMessageId = undefined;
messages.push(...saved.messages);
}
}
const output = "object" in toSave ? toSave.object : toSave.step;
if (this.options.rawRequestResponseHandler) {
await this.options.rawRequestResponseHandler(ctx, {
userId,
threadId,
agentName: this.options.name,
request: output.request,
response: output.response,
});
}
if (opts.usageHandler && output.usage) {
await opts.usageHandler(ctx, {
userId,
threadId,
agentName: this.options.name,
model: getModelName(activeModel),
provider: getProviderName(activeModel),
usage: output.usage,
providerMetadata: output.providerMetadata,
});
}
},
};
}
/**
* This behaves like {@link generateText} from the "ai" package except that
* it add context based on the userId and threadId and saves the input and
* resulting messages to the thread, if specified.
* Use {@link continueThread} to get a version of this function already scoped
* to a thread (and optionally userId).
* @param ctx The context passed from the action function calling this.
* @param { userId, threadId }: The user and thread to associate the message with
* @param generateTextArgs The arguments to the generateText function, along with extra controls
* for the {@link ContextOptions} and {@link StorageOptions}.
* @returns The result of the generateText function.
*/
async generateText(ctx, threadOpts, generateTextArgs, options) {
const { args, promptMessageId, order, ...call } = await this.start(ctx, generateTextArgs, { ...threadOpts, ...options });
const steps = [];
try {
const result = (await generateText({
...args,
prepareStep: async (options) => {
const result = await generateTextArgs.prepareStep?.(options);
call.updateModel(result?.model ?? options.model);
return result;
},
onStepFinish: async (step) => {
steps.push(step);
await call.save({ step }, await willContinue(steps, args.stopWhen));
return generateTextArgs.onStepFinish?.(step);
},
}));
const metadata = {
promptMessageId,
order,
savedMessages: call.getSavedMessages(),
messageId: promptMessageId,
};
return Object.assign(result, metadata);
}
catch (error) {
await call.fail(errorToString(error));
throw error;
}
}
/**
* This behaves like {@link streamText} from the "ai" package except that
* it add context based on the userId and threadId and saves the input and
* resulting messages to the thread, if specified.
* Use {@link continueThread} to get a version of this function already scoped
* to a thread (and optionally userId).
*/
async streamText(ctx, threadOpts,
/**
* The arguments to the streamText function, similar to the ai `streamText` function.
*/
streamTextArgs,
/**
* The {@link ContextOptions} and {@link StorageOptions}
* options to use for fetching contextual messages and saving input/output messages.
*/
options) {
const { threadId } = threadOpts;
const { args, userId, order, stepOrder, promptMessageId, ...call } = await this.start(ctx, streamTextArgs, { ...threadOpts, ...options });
const steps = [];
const opts = { ...this.options, ...options };
const streamer = threadId && opts.saveStreamDeltas
? new DeltaStreamer(this.component, ctx, opts.saveStreamDeltas, {
threadId,
userId,
agentName: this.options.name,
model: getModelName(args.model),
provider: getProviderName(args.model),
providerOptions: args.providerOptions,
order,
stepOrder,
abortSignal: args.abortSignal,
})
: undefined;
const result = streamText({
...args,
abortSignal: streamer?.abortController.signal ?? args.abortSignal,
// TODO: this is probably why reasoning isn't streaming
experimental_transform: mergeTransforms(options?.saveStreamDeltas, streamTextArgs.experimental_transform),
onChunk: async (event) => {
await streamer?.addParts([event.chunk]);
// console.log("onChunk", chunk);
return streamTextArgs.onChunk?.(event);
},
onError: async (error) => {
console.error("onError", error);
await call.fail(errorToString(error.error));
await streamer?.fail(errorToString(error.error));
return streamTextArgs.onError?.(error);
},
// onFinish: async (event) => {
// return streamTextArgs.onFinish?.(event);
// },
prepareStep: async (options) => {
const result = await streamTextArgs.prepareStep?.(options);
if (result) {
const model = result.model ?? options.model;
call.updateModel(model);
return result;
}
return undefined;
},
onStepFinish: async (step) => {
steps.push(step);
const createPendingMessage = await willContinue(steps, args.stopWhen);
await call.save({ step }, createPendingMessage);
if (!createPendingMessage) {
await streamer?.finish();
}
return args.onStepFinish?.(step);
},
});
const metadata = {
promptMessageId,
order,
savedMessages: call.getSavedMessages(),
messageId: promptMessageId,
};
if ((typeof options?.saveStreamDeltas === "object" &&
!options.saveStreamDeltas.returnImmediately) ||
options?.saveStreamDeltas === true) {
await result.consumeStream();
}
return Object.assign(result, metadata);
}
/**
* This behaves like {@link generateObject} from the "ai" package except that
* it add context based on the userId and threadId and saves the input and
* resulting messages to the thread, if specified.
* Use {@link continueThread} to get a version of this function already scoped
* to a thread (and optionally userId).
*/
async generateObject(ctx, threadOpts,
/**
* The arguments to the generateObject function, similar to the ai.generateObject function.
*/
generateObjectArgs,
/**
* The {@link ContextOptions} and {@link StorageOptions}
* options to use for fetching contextual messages and saving input/output messages.
*/
options) {
const { args, promptMessageId, order, fail, save, getSavedMessages } = await this.start(ctx, generateObjectArgs, { ...threadOpts, ...options });
try {
const result = (await generateObject(args));
await save({ object: result });
const metadata = {
promptMessageId,
order,
savedMessages: getSavedMessages(),
messageId: promptMessageId,
};
return Object.assign(result, metadata);
}
catch (error) {
await fail(errorToString(error));
throw error;
}
}
/**
* This behaves like `streamObject` from the "ai" package except that
* it add context based on the userId and threadId and saves the input and
* resulting messages to the thread, if specified.
* Use {@link continueThread} to get a version of this function already scoped
* to a thread (and optionally userId).
*/
async streamObject(ctx, threadOpts,
/**
* The arguments to the streamObject function, similar to the ai `streamObject` function.
*/
streamObjectArgs,
/**
* The {@link ContextOptions} and {@link StorageOptions}
* options to use for fetching contextual messages and saving input/output messages.
*/
options) {
const { args, promptMessageId, order, fail, save, getSavedMessages } = await this.start(ctx, streamObjectArgs, { ...threadOpts, ...options });
const stream = streamObject({
// eslint-disable-next-line @typescript-eslint/no-explicit-any
...args,
onError: async (error) => {
console.error(" streamObject onError", error);
// TODO: content that we have so far
// content: stream.fullStream.
await fail(errorToString(error.error));
return args.onError?.(error);
},
onFinish: async (result) => {
await save({
object: {
object: result.object,
finishReason: result.error ? "error" : "stop",
usage: result.usage,
warnings: result.warnings,
request: await stream.request,
response: result.response,
providerMetadata: result.providerMetadata,
toJsonResponse: stream.toTextStreamResponse,
},
});
return args.onFinish?.(result);
},
});
const metadata = {
promptMessageId,
order,
savedMessages: getSavedMessages(),
messageId: promptMessageId,
};
return Object.assign(stream, metadata);
}
/**
* Save a message to the thread.
* @param ctx A ctx object from a mutation or action.
* @param args The message and what to associate it with (user / thread)
* You can pass extra metadata alongside the message, e.g. associated fileIds.
* @returns The messageId of the saved message.
*/
async saveMessage(ctx, args) {
const { messages } = await this.saveMessages(ctx, {
threadId: args.threadId,
userId: args.userId,
embeddings: args.embedding
? { model: args.embedding.model, vectors: [args.embedding.vector] }
: undefined,
messages: args.prompt !== undefined
? [{ role: "user", content: args.prompt }]
: [args.message],
metadata: args.metadata ? [args.metadata] : undefined,
skipEmbeddings: args.skipEmbeddings,
pendingMessageId: args.pendingMessageId,
});
const message = messages.at(-1);
return { messageId: message._id, message };
}
/**
* Explicitly save messages associated with the thread (& user if provided)
* If you have an embedding model set, it will also generate embeddings for
* the messages.
* @param ctx The ctx parameter to a mutation or action.
* @param args The messages and context to save
* @returns
*/
async saveMessages(ctx, args) {
let embeddings;
const { skipEmbeddings, ...rest } = args;
if (args.embeddings) {
embeddings = args.embeddings;
}
else if (!skipEmbeddings && this.options.textEmbeddingModel) {
if (!("runAction" in ctx)) {
console.warn("You're trying to save messages and generate embeddings, but you're in a mutation. " +
"Pass `skipEmbeddings: true` to skip generating embeddings in the mutation and skip this warning. " +
"They will be generated lazily when you generate or stream text / objects. " +
"You can explicitly generate them asynchronously by using the scheduler to run an action later that calls `agent.generateAndSaveEmbeddings`.");
}
else if ("workflowId" in ctx) {
console.warn("You're trying to save messages and generate embeddings, but you're in a workflow. " +
"Pass `skipEmbeddings: true` to skip generating embeddings in the workflow and skip this warning. " +
"They will be generated lazily when you generate or stream text / objects. " +
"You can explicitly generate them asynchronously by using the scheduler to run an action later that calls `agent.generateAndSaveEmbeddings`.");
}
else {
embeddings = await this.generateEmbeddings(ctx, { userId: args.userId ?? undefined, threadId: args.threadId }, args.messages);
}
}
return saveMessages(ctx, this.component, {
...rest,
agentName: this.options.name,
embeddings,
});
}
/**
* List messages from a thread.
* @param ctx A ctx object from a query, mutation, or action.
* @param args.threadId The thread to list messages from.
* @param args.paginationOpts Pagination options (e.g. via usePaginatedQuery).
* @param args.excludeToolMessages Whether to exclude tool messages.
* False by default.
* @param args.statuses What statuses to include. All by default.
* @returns The MessageDoc's in a format compatible with usePaginatedQuery.
*/
async listMessages(ctx, args) {
return listMessages(ctx, this.component, args);
}
/**
* A function that handles fetching stream deltas, used with the React hooks
* `useThreadMessages` or `useStreamingThreadMessages`.
* @param ctx A ctx object from a query, mutation, or action.
* @param args.threadId The thread to sync streams for.
* @param args.streamArgs The stream arguments with per-stream cursors.
* @returns The deltas for each stream from their existing cursor.
*/
async syncStreams(ctx, args) {
return syncStreams(ctx, this.component, args);
}
/**
* Fetch the context messages for a thread.
* @param ctx Either a query, mutation, or action ctx.
* If it is not an action context, you can't do text or
* vector search.
* @param args The associated thread, user, message
* @returns
*/
async fetchContextMessages(ctx, args) {
assert(args.userId || args.threadId, "Specify userId or threadId");
const contextOptions = {
...this.options.contextOptions,
...args.contextOptions,
};
return fetchContextMessages(ctx, this.component, {
...args,
contextOptions,
getEmbedding: async (text) => {
assert("runAction" in ctx);
assert(this.options.textEmbeddingModel, "A textEmbeddingModel is required to be set on the Agent that you're doing vector search with");
return {
embedding: (await this.doEmbed(ctx, {
userId: args.userId,
threadId: args.threadId,
values: [text],
})).embeddings[0],
textEmbeddingModel: this.options.textEmbeddingModel,
};
},
});
}
/**
* Get the metadata for a thread.
* @param ctx A ctx object from a query, mutation, or action.
* @param args.threadId The thread to get the metadata for.
* @returns The metadata for the thread.
*/
async getThreadMetadata(ctx, args) {
return getThreadMetadata(ctx, this.component, args);
}
/**
* Update the metadata for a thread.
* @param ctx A ctx object from a mutation or action.
* @param args.threadId The thread to update the metadata for.
* @param args.patch The patch to apply to the thread.
* @returns The updated thread metadata.
*/
async updateThreadMetadata(ctx, args) {
const thread = await ctx.runMutation(this.component.threads.updateThread, args);
return thread;
}
/**
* Get the embeddings for a set of messages.
* @param messages The messages to get the embeddings for.
* @returns The embeddings for the messages.
*/
async generateEmbeddings(ctx, { userId, threadId, }, messages) {
if (!this.options.textEmbeddingModel) {
return undefined;
}
let embeddings;
const messageTexts = messages.map((m) => !isTool(m) && extractText(m));
// Find the indexes of the messages that have text.
const textIndexes = messageTexts
.map((t, i) => (t ? i : undefined))
.filter((i) => i !== undefined);
if (textIndexes.length === 0) {
return undefined;
}
const values = messageTexts.filter((t) => !!t);
// Then embed those messages.
const textEmbeddings = await this.doEmbed(ctx, {
userId,
threadId,
values,
});
// Then assemble the embeddings into a single array with nulls for the messages without text.
const embeddingsOrNull = Array(messages.length).fill(null);
textIndexes.forEach((i, j) => {
embeddingsOrNull[i] = textEmbeddings.embeddings[j];
});
if (textEmbeddings.embeddings.length > 0) {
const dimension = textEmbeddings.embeddings[0].length;
validateVectorDimension(dimension);
const model = getModelName(this.options.textEmbeddingModel);
embeddings = { vectors: embeddingsOrNull, dimension, model };
}
return embeddings;
}
/**
* Generate embeddings for a set of messages, and save them to the database.
* It will not generate or save embeddings for messages that already have an
* embedding.
* @param ctx The ctx parameter to an action.
* @param args The messageIds to generate embeddings for.
*/
async generateAndSaveEmbeddings(ctx, args) {
const messages = (await ctx.runQuery(this.component.messages.getMessagesByIds, {
messageIds: args.messageIds,
})).filter((m) => m !== null);
if (messages.length !== args.messageIds.length) {
throw new Error("Some messages were not found: " +
args.messageIds
.filter((id) => !messages.some((m) => m?._id === id))
.join(", "));
}
await this._generateAndSaveEmbeddings(ctx, messages);
}
async _generateAndSaveEmbeddings(ctx, messages) {
if (messages.some((m) => !m.message)) {
throw new Error("Some messages don't have a message: " +
messages
.filter((m) => !m.message)
.map((m) => m._id)
.join(", "));
}
const messagesMissingEmbeddings = messages.filter((m) => !m.embeddingId);
if (messagesMissingEmbeddings.length === 0) {
return;
}
const embeddings = await this.generateEmbeddings(ctx, {
userId: messagesMissingEmbeddings[0].userId,
threadId: messagesMissingEmbeddings[0].threadId,
}, messagesMissingEmbeddings.map((m) => deserializeMessage(m.message)));
if (!embeddings) {
if (!this.options.textEmbeddingModel) {
throw new Error("No embeddings were generated for the messages. You must pass a textEmbeddingModel to the agent constructor.");
}
throw new Error("No embeddings were generated for these messages: " +
messagesMissingEmbeddings.map((m) => m._id).join(", "));
}
await ctx.runMutation(this.component.vector.index.insertBatch, {
vectorDimension: embeddings.dimension,
vectors: messagesMissingEmbeddings
.map((m, i) => ({
messageId: m._id,
model: embeddings.model,
table: "messages",
userId: m.userId,
threadId: m.threadId,
vector: embeddings.vectors[i],
}))
.filter((v) => v.vector !== null),
});
}
/**
* Explicitly save a "step" created by the AI SDK.
* @param ctx The ctx argument to a mutation or action.
* @param args The Step generated by the AI SDK.
*/
async saveStep(ctx, args) {
const { messages } = await serializeNewMessagesInStep(ctx, this.component, args.step, {
provider: args.provider ?? getProviderName(this.options.languageModel),
model: args.model ?? getModelName(this.options.languageModel),
});
const embeddings = await this.generateEmbeddings(ctx, { userId: args.userId, threadId: args.threadId }, messages.map((m) => m.message));
return ctx.runMutation(this.component.messages.addMessages, {
userId: args.userId,
threadId: args.threadId,
agentName: this.options.name,
promptMessageId: args.promptMessageId,
messages,
embeddings,
failPendingSteps: false,
});
}
/**
* Manually save the result of a generateObject call to the thread.
* This happens automatically when using {@link generateObject} or {@link streamObject}
* from the `thread` object created by {@link continueThread} or {@link createThread}.
* @param ctx The context passed from the mutation or action function calling this.
* @param args The arguments to the saveObject function.
*/
async saveObject(ctx, args) {
const { messages } = await serializeObjectResult(ctx, this.component, args.result, {
model: args.model ??
args.metadata?.model ??
getModelName(this.options.languageModel),
provider: args.provider ??
args.metadata?.provider ??
getProviderName(this.options.languageModel),
});
const embeddings = await this.generateEmbeddings(ctx, { userId: args.userId, threadId: args.threadId }, messages.map((m) => m.message));
return ctx.runMutation(this.component.messages.addMessages, {
userId: args.userId,
threadId: args.threadId,
promptMessageId: args.promptMessageId,
failPendingSteps: false,
messages,
embeddings,
agentName: this.options.name,
});
}
/**
* Commit or rollback a message that was pending.
* This is done automatically when saving messages by default.
* If creating pending messages, you can call this when the full "transaction" is done.
* @param ctx The ctx argument to your mutation or action.
* @param args What message to save. Generally the parent message sent into
* the generateText call.
*/
async finalizeMessage(ctx, args) {
await ctx.runMutation(this.component.messages.finalizeMessage, {
messageId: args.messageId,
result: args.result,
});
}
/**
* Update a message by its id.
* @param ctx The ctx argument to your mutation or action.
* @param args The message fields to update.
*/
async updateMessage(ctx, args) {
const { message, fileIds } = await serializeMessage(ctx, this.component, args.patch.message);
await ctx.runMutation(this.component.messages.updateMessage, {
messageId: args.messageId,
patch: {
message,
fileIds: args.patch.fileIds
? [...args.patch.fileIds, ...(fileIds ?? [])]
: fileIds,
status: args.patch.status === "success" ? "success" : "failed",
error: args.patch.error,
},
});
}
/**
* Delete multiple messages by their ids, including their embeddings
* and reduce the refcount of any files they reference.
* @param ctx The ctx argument to your mutation or action.
* @param args The ids of the messages to delete.
*/
async deleteMessages(ctx, args) {
await ctx.runMutation(this.component.messages.deleteByIds, args);
}
/**
* Delete a single message by its id, including its embedding
* and reduce the refcount of any files it references.
* @param ctx The ctx argument to your mutation or action.
* @param args The id of the message to delete.
*/
async deleteMessage(ctx, args) {
await ctx.runMutation(this.component.messages.deleteByIds, {
messageIds: [args.messageId],
});
}
/**
* Delete a range of messages by their order and step order.
* Each "order" is a set of associated messages in response to the message
* at stepOrder 0.
* The (startOrder, startStepOrder) is inclusive
* and the (endOrder, endStepOrder) is exclusive.
* To delete all messages at "order" 1, you can pass:
* `{ startOrder: 1, endOrder: 2 }`
* To delete a message at step (order=1, stepOrder=1), you can pass:
* `{ startOrder: 1, startStepOrder: 1, endOrder: 1, endStepOrder: 2 }`
* To delete all messages between (1, 1) up to and including (3, 5), you can pass:
* `{ startOrder: 1, startStepOrder: 1, endOrder: 3, endStepOrder: 6 }`
*
* If it cannot do it in one transaction, it returns information you can use
* to resume the deletion.
* e.g.
* ```ts
* let isDone = false;
* let lastOrder = args.startOrder;
* let lastStepOrder = args.startStepOrder ?? 0;
* while (!isDone) {
* // eslint-disable-next-line @typescript-eslint/no-explicit-any
* ({ isDone, lastOrder, lastStepOrder } = await agent.deleteMessageRange(
* ctx,
* {
* threadId: args.threadId,
* startOrder: lastOrder,
* startStepOrder: lastStepOrder,
* endOrder: args.endOrder,
* endStepOrder: args.endStepOrder,
* }
* ));
* }
* ```
* @param ctx The ctx argument to your mutation or action.
* @param args The range of messages to delete.
*/
async deleteMessageRange(ctx, args) {
return ctx.runMutation(this.component.messages.deleteByOrder, {
threadId: args.threadId,
startOrder: args.startOrder,
startStepOrder: args.startStepOrder,
endOrder: args.endOrder,
endStepOrder: args.endStepOrder,
});
}
/**
* Delete a thread and all its messages and streams asynchronously (in batches)
* This uses a mutation to that processes one page and recursively queues the
* next page for deletion.
* @param ctx The ctx argument to your mutation or action.
* @param args The id of the thread to delete and optionally the page size to use for the delete.
*/
async deleteThreadAsync(ctx, args) {
await ctx.runMutation(this.component.threads.deleteAllForThreadIdAsync, {
threadId: args.threadId,
limit: args.pageSize,
});
}
/**
* Delete a thread and all its messages and streams synchronously.
* This uses an action to iterate through all pages. If the action fails
* partway, it will not automatically restart.
* @param ctx The ctx argument to your action.
* @param args The id of the thread to delete and optionally the page size to use for the delete.
*/
async deleteThreadSync(ctx, args) {
await ctx.runAction(this.component.threads.deleteAllForThreadIdSync, {
threadId: args.threadId,
limit: args.pageSize,
});
}
async _saveMessagesAndFetchContext(ctx, args, { userId: argsUserId, threadId, contextOptions, storageOptions, }) {
// If only a promptMessageId is provided, this will be empty.
const messages = args.messages ?? [];
const prompt = !args.prompt
? []
: Array.isArray(args.prompt)
? args.prompt
: [{ role: "user", content: args.prompt }];
const userId = argsUserId ??
(threadId &&
(await ctx.runQuery(this.component.threads.getThread, { threadId }))
?.userId) ??
undefined;
// If only a messageId is provided, this will add that message to the end.
const contextMessages = await this.fetchContextMessages(ctx, {
userId,
threadId,
upToAndIncludingMessageId: args.promptMessageId,
messages,
contextOptions,
});
// If it was a promptMessageId, pop it off context messages
// and add to the end of messages.
const promptMessageIndex = args.promptMessageId
? contextMessages.findIndex((m) => m._id === args.promptMessageId)
: -1;
const promptMessage = promptMessageIndex !== -1
? contextMessages.splice(promptMessageIndex, 1)[0]
: undefined;
let promptMessageId = promptMessage?._id;
let order = promptMessage?.order;
let stepOrder = promptMessage?.stepOrder;
let savedMessages = undefined;
let pendingMessageId = undefined;
if (threadId && storageOptions?.saveMessages !== "none") {
let saved;
if (messages.length + prompt.length &&
// If it was a promptMessageId, we don't want to save it again.
(!args.promptMessageId || storageOptions?.saveMessages === "all")) {
const saveAll = storageOptions?.saveMessages === "all";
const coreMessages = [...messages, ...prompt];
const toSave = saveAll ? coreMessages : coreMessages.slice(-1);
const metadata = Array.from({ length: toSave.length }, () => ({}));
saved = await this.saveMessages(ctx, {
threadId,
userId,
messages: [...toSave, { role: "assistant", content: [] }],
metadata: [...metadata, { status: "pending" }],
failPendingSteps: true,
pendingMessageId: args.pendingMessageId,
});
promptMessageId = saved.messages.at(-2)._id;
}
else {
saved = await this.saveMessages(ctx, {
threadId,
userId,
messages: [{ role: "assistant", content: [] }],
metadata: [{ status: "pending" }],
failPendingSteps: true,
pendingMessageId: args.pendingMessageId,
});
}
pendingMessageId = saved.messages.at(-1)._id;
order = saved.messages.at(-1).order;
stepOrder = saved.messages.at(-1).stepOrder;
// Don't return the pending message
savedMessages = saved.messages.slice(0, -1);
}
if (promptMessage?.message) {
if (!args.prompt) {
// If they override the prompt, we skip the existing prompt message.
messages.push(deserializeMessage(promptMessage.message));
}
// Lazily generate embeddings for the prompt message, if it doesn't have
// embeddings yet. This can happen if the message was saved in a mutation
// where the LLM is not available.
if (!promptMessage.embeddingId && this.options.textEmbeddingModel) {
await this._generateAndSaveEmbeddings(ctx, [promptMessage]);
}
}
const prePrompt = contextMessages.map((m) => deserializeMessage(m.message));
let existingResponses = [];
if (promptMessageIndex !== -1) {
// pull any messages that already responded to the prompt off
// and add them after the prompt
existingResponses = prePrompt.splice(promptMessageIndex);
}
let processedMessages = [
...prePrompt,
...messages,
...prompt,
...existingResponses,
];
if (promptMessageIndex === -1) {
processedMessages.push(...prompt);
}
else {
// We add the prompt where the prompt message was
processedMessages.splice(promptMessageIndex, 0, ...prompt);
}
// Process messages to inline localhost files (if not, file urls pointing to localhost will be sent to LLM providers)
if (process.env.CONVEX_CLOUD_URL?.startsWith("http://127.0.0.1")) {
processedMessages = await inlineMessagesFiles(processedMessages);
}
const { prompt: _, model, ...rest } = args;
return {
args: {
...this.options.callSettings,
...this.options.providerOptions,
...rest,
model: model ?? this.options.languageModel,
system: args.system ?? this.options.instructions,
messages: processedMessages,
},
userId,
promptMessageId,
pendingMessageId,
savedMessages,
order,
stepOrder,
};
}
async doEmbed(ctx, options) {
const embeddingModel = this.options.textEmbeddingModel;
assert(embeddingModel, "a textEmbeddingModel is required to be set on the Agent that you're doing vector search with");
const result = await embedMany({
...this.options.callSettings,
model: embeddingModel,
values: options.values,
abortSignal: options.abortSignal,
headers: options.headers,
});
if (this.options.usageHandler && result.usage) {
await this.options.usageHandler(ctx, {
userId: options.userId,
threadId: options.threadId,
agentName: this.options.name,
model: getModelName(embeddingModel),
provider: getProviderName(embeddingModel),
providerMetadata: undefined,
usage: {
inputTokens: result.usage.tokens,
outputTokens: 0,
totalTokens: result.usage.tokens,
},
});
}
return { embeddings: result.embeddings };
}
/**
* WORKFLOW UTILITIES
*/
/**
* Create a mutation that creates a thread so you can call it from a Workflow.
* e.g.
* ```ts
* // in convex/foo.ts
* export const createThread = weatherAgent.createThreadMutation();
*
* const workflow = new WorkflowManager(components.workflow);
* export const myWorkflow = workflow.define({
* args: {},
* handler: async (step) => {
* const { threadId } = await step.runMutation(internal.foo.createThread);
* // use the threadId to generate text, object, etc.
* },
* });
* ```
* @returns A mutation that creates a thread.
*/
createThreadMutation() {
return internalMutationGeneric({
args: {
userId: v.optional(v.string()),
title: v.optional(v.string()),
summary: v.optional(v.string()),
},
handler: async (ctx, args) => {
const { threadId } = await this.createThread(ctx, args);
return { threadId };
},
});
}
/**
* Create an action out of this agent so you can call it from workflows or other actions
* without a wrapping function.
* @param spec Configuration for the agent acting as an action, including
* {@link ContextOptions}, {@link StorageOptions}, and {@link stopWhen}.
*/
asTextAction(spec, overrides) {
    return internalActionGeneric({
        args: vTextArgs,
        handler: async (ctx_, args) => {
            // Streaming decision: an explicit runtime `stream: true` defers to
            // spec.stream (falling back to true via `||`); otherwise spec.stream
            // decides, defaulting to false.
            const stream = args.stream === true ? spec?.stream || true : (spec?.stream ?? false);
            const targetArgs = { userId: args.userId, threadId: args.threadId };
            // LLM argument precedence (later spreads win): spec/agent stopWhen
            // < overrides < runtime args. storageOptions/contextOptions are
            // stripped here — they're merged into `opts` below instead.
            const llmArgs = {
                stopWhen: spec?.stopWhen ?? this.options.stopWhen,
                ...overrides,
                ...omit(args, ["storageOptions", "contextOptions"]),
                // Messages and array prompts cross the action boundary in
                // serialized form; rehydrate them before the LLM call.
                messages: args.messages?.map(deserializeMessage),
                prompt: Array.isArray(args.prompt)
                    ? args.prompt.map(deserializeMessage)
                    : args.prompt,
                toolChoice: args.toolChoice,
            };
            // A provided maxSteps overrides any stopWhen set above.
            // NOTE(review): the truthiness check means maxSteps: 0 is ignored —
            // presumably intentional, but confirm.
            if (args.maxSteps) {
                llmArgs.stopWhen = stepCountIs(args.maxSteps);
            }
            // Option precedence (later spreads win): agent defaults < spec <
            // runtime args, for contextOptions/storageOptions only.
            const opts = {
                ...this.options,
                ...pick(spec, ["contextOptions", "storageOptions"]),
                ...pick(args, ["contextOptions", "storageOptions"]),
                saveStreamDeltas: stream,
            };
            // Let the spec merge extra fields over the base action ctx.
            const ctx = (spec?.customCtx
                ? { ...ctx_, ...spec.customCtx(ctx_, targetArgs, llmArgs) }
                : ctx_);
            if (stream) {
                // eslint-disable-next-line @typescript-eslint/no-explicit-any
                const result = await this.streamText(ctx, targetArgs, llmArgs, opts);
                // Drain the stream so the action resolves only after generation
                // (and delta persistence) completes.
                await result.consumeStream();
                return {
                    text: await result.text,
                    promptMessageId: result.promptMessageId,
                    order: result.order,
                    finishReason: await result.finishReason,
                    warnings: result.warnings,
                    savedMessageIds: result.savedMessages?.map((m) => m._id) ?? [],
                };
            }
            else {
                // eslint-disable-next-line @typescript-eslint/no-explicit-any
                const res = await this.generateText(ctx, targetArgs, llmArgs, opts);
                // Same return shape as the streaming branch, so callers don't
                // need to know which path ran.
                return {
                    text: res.text,
                    promptMessageId: res.promptMessageId,
                    order: res.order,
                    finishReason: res.finishReason,
                    warnings: res.warnings,
                    savedMessageIds: res.savedMessages?.map((m) => m._id) ?? [],
                };
            }
        },
    });
}
/**
* Create an action that generates an object out of this agent so you can call
* it from workflows or other actions without a wrapping function.
* @param spec Configuration for the agent acting as an action, including
* the normal parameters to {@link generateObject}, plus {@link ContextOptions}
* and stopWhen.
*/
asObjectAction(objectArgs, options) {
return internalActionGeneric({
args: vSafeObjectArgs,
handler: async (ctx_, args) => {
const { userId, threadId, callSettings, ...rest } = args;
const overrides = pick(rest, ["contextOptions", "storageOptions"]);
const targetArgs = { userId, threadId };
const llmArgs = {
...objectArgs,
...callSettings,
...omit(rest, ["storageOptions", "contextOptions"]),
messages: args.messages?.map(deserializeMessage),
prompt: Array.isArray(args.prompt)
? args.prompt.map(deserializeMessage)
: args.prompt,
};
const ctx = (options?.customCtx
? { ...ctx_, ...options.customCtx(ctx_, targetArgs, llmArgs) }
: ctx_);
const value = await this.generateObject(ctx, targetArgs, llmArgs, {
...this.options,
...options,
...overrides,
});
return {
object: convexToJson(value.object),
promptMessageId: value.promptMessageId,
order: value.order,
finishReason: value.finishReason,
warnings: value.warnings,
savedMessageIds: value.savedMessages?.map((m) => m._id) ?? [],