UNPKG

@convex-dev/agent

Version:

An agent component for Convex.

171 lines 6.75 kB
import { stepCountIs } from "ai";
import {
  serializeNewMessagesInStep,
  serializeObjectResult,
} from "../mapping.js";
import { embedMessages, fetchContextWithPrompt } from "./search.js";
import { getModelName, getProviderName } from "../shared.js";
import { wrapTools } from "./createTool.js";
import { assert, omit } from "convex-helpers";
import { saveInputMessages } from "./saveInputMessages.js";

/**
 * Prepare an LLM call for a thread.
 *
 * It resolves the user, fetches context messages and (unless saving is
 * disabled or there is no thread) persists the input messages. It returns
 * the arguments to hand to the AI SDK, plus callbacks to save results and
 * mark the generation as failed.
 *
 * @param ctx - Convex context; this function uses `ctx.runQuery` and
 *   `ctx.runMutation`. It is also spread into the tool context.
 * @param component - The agent component API reference (`threads.getThread`,
 *   `messages.finalizeMessage`, `messages.addMessages` are used here).
 * @param args - The arguments you'll pass to the LLM call, such as
 *   `generateText` or `streamText`. This function will look up the context
 *   and provide functions to save the steps, abort the generation, and more.
 *   The type of the arguments returned infers from the type of the arguments
 *   you pass here.
 * @param options - `threadId` plus agent options. Those read here are
 *   `userId`, `languageModel`, `storageOptions`, `callSettings`,
 *   `providerOptions`, `maxSteps`, `agentName`, `agentForToolCtx`,
 *   `usageHandler` and `rawRequestResponseHandler`. All of `opts` is also
 *   forwarded to `fetchContextWithPrompt`, `saveInputMessages` and
 *   `embedMessages`.
 * @returns An object with these fields:
 *   - `args`: arguments for the AI SDK call (context messages, wrapped
 *     tools, resolved model, `stopWhen`).
 *   - `order` / `stepOrder`: taken from the pending message or the context,
 *     defaulting to 0.
 *   - `userId`, `promptMessageId`.
 *   - `getSavedMessages()`: messages saved so far.
 *   - `updateModel(model)`: switches the model used to attribute later saves
 *     and usage.
 *   - `fail(reason)`: finalizes the current pending message, if any, as
 *     failed.
 *   - `save(toSave, createPendingMessage)`: persists a step or object result,
 *     then invokes the request/usage handlers.
 * @throws If neither `args.model` nor `opts.languageModel` is provided.
 */
export async function startGeneration(ctx, component, args, { threadId, ...opts }) {
  // Prefer an explicit userId; otherwise fall back to the thread's owner.
  const userId =
    opts.userId ??
    (threadId &&
      (await ctx.runQuery(component.threads.getThread, { threadId }))?.userId) ??
    undefined;

  const context = await fetchContextWithPrompt(ctx, component, {
    ...opts,
    userId,
    threadId,
    messages: args.messages,
    prompt: args.prompt,
    promptMessageId: args.promptMessageId,
  });

  const saveMessages = opts.storageOptions?.saveMessages ?? "promptAndOutput";
  const { promptMessageId, pendingMessage, savedMessages } =
    threadId && saveMessages !== "none"
      ? await saveInputMessages(ctx, component, {
          ...opts,
          userId,
          threadId,
          prompt: args.prompt,
          messages: args.messages,
          promptMessageId: args.promptMessageId,
          storageOptions: { saveMessages },
        })
      : {
          promptMessageId: args.promptMessageId,
          pendingMessage: undefined,
          savedMessages: [],
        };

  const order = pendingMessage?.order ?? context.order;
  const stepOrder = pendingMessage?.stepOrder ?? context.stepOrder;
  // Mutable: `save` advances this to the next pending message or clears it,
  // and `fail` / the `generateId` closure read the live value.
  let pendingMessageId = pendingMessage?._id;

  const model = args.model ?? opts.languageModel;
  assert(model, "model is required");
  // Model used when serializing saved messages and reporting usage.
  let activeModel = model;

  /** Mark the current pending message (if any) as failed with `reason`. */
  const fail = async (reason) => {
    if (pendingMessageId) {
      await ctx.runMutation(component.messages.finalizeMessage, {
        messageId: pendingMessageId,
        result: { status: "failed", error: reason },
      });
    }
  };

  if (args.abortSignal) {
    const abortSignal = args.abortSignal;
    abortSignal.addEventListener(
      "abort",
      async () => {
        // EventTarget ignores a listener's returned promise, so a rejection
        // from `fail` must be handled here rather than left unhandled.
        try {
          await fail(abortSignal.reason?.toString() ?? "abortSignal");
        } catch (error) {
          console.error("Failed to mark pending message as failed on abort", error);
        }
      },
      { once: true },
    );
  }

  const toolCtx = {
    ...ctx,
    userId,
    threadId,
    promptMessageId,
    agent: opts.agentForToolCtx,
  };
  const tools = wrapTools(toolCtx, args.tools);

  const aiArgs = {
    ...opts.callSettings,
    providerOptions: opts.providerOptions,
    ...omit(args, ["promptMessageId", "messages", "prompt"]),
    model,
    messages: context.messages,
    stopWhen:
      args.stopWhen ?? (opts.maxSteps ? stepCountIs(opts.maxSteps) : undefined),
    tools,
  };

  // While a pending message exists, its id is handed to the AI SDK as the
  // generated id, unless the caller supplied its own generator. The closure
  // reads the live binding, so fresh UUIDs are used once `save` clears it.
  if (pendingMessageId && !aiArgs._internal?.generateId) {
    aiArgs._internal = {
      ...aiArgs._internal,
      generateId: () => pendingMessageId ?? crypto.randomUUID(),
    };
  }

  return {
    args: aiArgs,
    order: order ?? 0,
    stepOrder: stepOrder ?? 0,
    userId,
    promptMessageId,
    getSavedMessages: () => savedMessages,
    updateModel: (nextModel) => {
      if (nextModel) {
        activeModel = nextModel;
      }
    },
    fail,
    save: async (toSave, createPendingMessage) => {
      if (threadId && saveMessages !== "none") {
        const serialized =
          "object" in toSave
            ? await serializeObjectResult(ctx, component, toSave.object, activeModel)
            : await serializeNewMessagesInStep(ctx, component, toSave.step, activeModel);
        const embeddings = await embedMessages(
          ctx,
          { threadId, ...opts, userId },
          serialized.messages.map((m) => m.message),
        );
        if (createPendingMessage) {
          // Append an empty assistant placeholder for the next step, keeping
          // embeddings index-aligned with the messages.
          serialized.messages.push({
            message: { role: "assistant", content: [] },
            status: "pending",
          });
          embeddings?.vectors.push(null);
        }
        const saved = await ctx.runMutation(component.messages.addMessages, {
          userId,
          threadId,
          agentName: opts.agentName,
          promptMessageId,
          pendingMessageId,
          messages: serialized.messages,
          embeddings,
          failPendingSteps: false,
        });
        const lastMessage = saved.messages.at(-1);
        if (createPendingMessage) {
          if (lastMessage.status === "failed") {
            pendingMessageId = undefined;
            savedMessages.push(...saved.messages);
            // NOTE(review): `pendingMessageId` was cleared just above, so
            // this call does not reach `finalizeMessage`. Confirm it is
            // intentional.
            await fail(
              lastMessage.error ??
                "Aborting - the pending message was marked as failed",
            );
          } else {
            // The new placeholder becomes the pending message; it is not
            // counted among the saved messages yet.
            pendingMessageId = lastMessage._id;
            savedMessages.push(...saved.messages.slice(0, -1));
          }
        } else {
          pendingMessageId = undefined;
          savedMessages.push(...saved.messages);
        }
      }

      const output = "object" in toSave ? toSave.object : toSave.step;
      if (opts.rawRequestResponseHandler) {
        await opts.rawRequestResponseHandler(ctx, {
          userId,
          threadId,
          agentName: opts.agentName,
          request: output.request,
          response: output.response,
        });
      }
      if (opts.usageHandler && output.usage) {
        await opts.usageHandler(ctx, {
          userId,
          threadId,
          agentName: opts.agentName,
          model: getModelName(activeModel),
          provider: getProviderName(activeModel),
          usage: output.usage,
          providerMetadata: output.providerMetadata,
        });
      }
    },
  };
}
//# sourceMappingURL=start.js.map