@convex-dev/agent
Version:
An agent component for Convex.
171 lines • 6.75 kB
JavaScript
import { stepCountIs, } from "ai";
import { serializeNewMessagesInStep, serializeObjectResult, } from "../mapping.js";
import { embedMessages, fetchContextWithPrompt } from "./search.js";
import { getModelName, getProviderName, } from "../shared.js";
import { wrapTools } from "./createTool.js";
import { assert, omit } from "convex-helpers";
import { saveInputMessages } from "./saveInputMessages.js";
/**
 * Prepares an LLM call (such as `generateText` or `streamText`) for an agent.
 *
 * Steps performed:
 * 1. Resolves `userId` from `opts.userId`, falling back to the owner of
 *    `threadId` (looked up via `component.threads.getThread`).
 * 2. Fetches context messages for the prompt via `fetchContextWithPrompt`.
 * 3. When there is a thread and `storageOptions.saveMessages` is not
 *    `"none"` (default `"promptAndOutput"`), saves the input messages via
 *    `saveInputMessages`, which may return a pending message.
 * 4. Builds the AI SDK arguments: call settings, provider options, the
 *    caller's args (minus `promptMessageId` / `messages` / `prompt`), the
 *    model, the context messages, a stop condition, and wrapped tools.
 *
 * If `args.abortSignal` is (or becomes) aborted, the pending message, if
 * any, is marked as failed with the abort reason.
 *
 * @param ctx - Context providing `runQuery` and `runMutation`. Tools receive
 *   a shallow copy extended with `userId`, `threadId`, `promptMessageId`
 *   and `agent`.
 * @param component - The agent component's function references.
 * @param args - LLM call arguments (see the inline doc on the parameter).
 * @param opts - `threadId` plus agent options (e.g. `userId`,
 *   `languageModel`, `storageOptions`, `callSettings`, `providerOptions`,
 *   `maxSteps`, `agentName`, `agentForToolCtx`,
 *   `rawRequestResponseHandler`, `usageHandler`). These are also forwarded
 *   to the context, save and embedding helpers.
 * @returns An object with the AI SDK `args`, the `order` / `stepOrder` to
 *   use (default 0), the resolved `userId` and `promptMessageId`, and the
 *   helpers `getSavedMessages`, `updateModel`, `fail` and `save`.
 * @throws If neither `args.model` nor `opts.languageModel` is provided.
 */
export async function startGeneration(
    ctx,
    component,
    /**
     * These are the arguments you'll pass to the LLM call such as
     * `generateText` or `streamText`. This function will look up the context
     * and provide functions to save the steps, abort the generation, and more.
     * The type of the returned arguments is inferred from the type of the
     * arguments you pass here.
     */
    args,
    { threadId, ...opts },
) {
    // Prefer an explicit userId; otherwise fall back to the thread's owner.
    const userId = opts.userId ??
        (threadId &&
            (await ctx.runQuery(component.threads.getThread, { threadId }))
                ?.userId) ??
        undefined;
    const context = await fetchContextWithPrompt(ctx, component, {
        ...opts,
        userId,
        threadId,
        messages: args.messages,
        prompt: args.prompt,
        promptMessageId: args.promptMessageId,
    });
    const saveMessages = opts.storageOptions?.saveMessages ?? "promptAndOutput";
    // Input messages are only persisted when there is a thread to save them to.
    const { promptMessageId, pendingMessage, savedMessages } = threadId && saveMessages !== "none"
        ? await saveInputMessages(ctx, component, {
            ...opts,
            userId,
            threadId,
            prompt: args.prompt,
            messages: args.messages,
            promptMessageId: args.promptMessageId,
            storageOptions: { saveMessages },
        })
        : {
            promptMessageId: args.promptMessageId,
            pendingMessage: undefined,
            savedMessages: [],
        };
    // A pending message's position takes precedence over the context's.
    const order = pendingMessage?.order ?? context.order;
    const stepOrder = pendingMessage?.stepOrder ?? context.stepOrder;
    // Mutable: `save` replaces or clears this as steps are persisted.
    let pendingMessageId = pendingMessage?._id;
    const model = args.model ?? opts.languageModel;
    assert(model, "model is required");
    // Mutable: `updateModel` swaps the model used when serializing results
    // and when reporting usage.
    let activeModel = model;
    /** Marks the current pending message (if any) as failed with `reason`. */
    const fail = async (reason) => {
        if (pendingMessageId) {
            await ctx.runMutation(component.messages.finalizeMessage, {
                messageId: pendingMessageId,
                result: { status: "failed", error: reason },
            });
        }
    };
    if (args.abortSignal) {
        const abortSignal = args.abortSignal;
        const abortReason = () => abortSignal.reason?.toString() ?? "abortSignal";
        if (abortSignal.aborted) {
            // "abort" is dispatched only once. A signal that fired during the
            // awaits above would never reach a listener registered now, which
            // would leave the pending message stuck as pending.
            await fail(abortReason());
        }
        else {
            abortSignal.addEventListener("abort", () => {
                // Listener return values are ignored, so handle rejections
                // here instead of leaving an unhandled promise rejection.
                fail(abortReason()).catch((error) => {
                    console.error("Failed to mark pending message as failed after abort:", error);
                });
            }, { once: true });
        }
    }
    const toolCtx = {
        ...ctx,
        userId,
        threadId,
        promptMessageId,
        agent: opts.agentForToolCtx,
    };
    const tools = wrapTools(toolCtx, args.tools);
    const aiArgs = {
        ...opts.callSettings,
        providerOptions: opts.providerOptions,
        ...omit(args, ["promptMessageId", "messages", "prompt"]),
        model,
        messages: context.messages,
        stopWhen: args.stopWhen ?? (opts.maxSteps ? stepCountIs(opts.maxSteps) : undefined),
        tools,
    };
    if (pendingMessageId && !aiArgs._internal?.generateId) {
        // Generated ids reuse the current pending message id while one is
        // set (the closure reads the live variable), else a random UUID.
        aiArgs._internal = {
            ...aiArgs._internal,
            generateId: () => pendingMessageId ?? crypto.randomUUID(),
        };
    }
    return {
        args: aiArgs,
        order: order ?? 0,
        stepOrder: stepOrder ?? 0,
        userId,
        promptMessageId,
        getSavedMessages: () => savedMessages,
        updateModel: (model) => {
            if (model) {
                activeModel = model;
            }
        },
        fail,
        /**
         * Persists a step (`{ step }`) or object result (`{ object }`) when
         * saving is enabled, optionally appending a new pending assistant
         * message, then invokes the raw request/response and usage handlers.
         */
        save: async (toSave, createPendingMessage) => {
            if (threadId && saveMessages !== "none") {
                const serialized = "object" in toSave
                    ? await serializeObjectResult(ctx, component, toSave.object, activeModel)
                    : await serializeNewMessagesInStep(ctx, component, toSave.step, activeModel);
                const embeddings = await embedMessages(ctx, { threadId, ...opts, userId }, serialized.messages.map((m) => m.message));
                if (createPendingMessage) {
                    serialized.messages.push({
                        message: { role: "assistant", content: [] },
                        status: "pending",
                    });
                    // Keep vectors aligned with messages; the pending one has none.
                    embeddings?.vectors.push(null);
                }
                const saved = await ctx.runMutation(component.messages.addMessages, {
                    userId,
                    threadId,
                    agentName: opts.agentName,
                    promptMessageId,
                    pendingMessageId,
                    messages: serialized.messages,
                    embeddings,
                    failPendingSteps: false,
                });
                const lastMessage = saved.messages.at(-1);
                if (createPendingMessage) {
                    if (lastMessage.status === "failed") {
                        pendingMessageId = undefined;
                        savedMessages.push(...saved.messages);
                        // NOTE(review): `pendingMessageId` was just cleared, so
                        // this `fail` does not call `finalizeMessage` — confirm
                        // whether it is meant to have any effect here.
                        await fail(lastMessage.error ??
                            "Aborting - the pending message was marked as failed");
                    }
                    else {
                        // The new pending message becomes the one to fill next;
                        // it is not reported as saved yet.
                        pendingMessageId = lastMessage._id;
                        savedMessages.push(...saved.messages.slice(0, -1));
                    }
                }
                else {
                    pendingMessageId = undefined;
                    savedMessages.push(...saved.messages);
                }
            }
            const output = "object" in toSave ? toSave.object : toSave.step;
            if (opts.rawRequestResponseHandler) {
                await opts.rawRequestResponseHandler(ctx, {
                    userId,
                    threadId,
                    agentName: opts.agentName,
                    request: output.request,
                    response: output.response,
                });
            }
            if (opts.usageHandler && output.usage) {
                await opts.usageHandler(ctx, {
                    userId,
                    threadId,
                    agentName: opts.agentName,
                    model: getModelName(activeModel),
                    provider: getProviderName(activeModel),
                    usage: output.usage,
                    providerMetadata: output.providerMetadata,
                });
            }
        },
    };
}
//# sourceMappingURL=start.js.map