@convex-dev/agent
An agent component for Convex.
1,489 lines (1,457 loc) • 70.1 kB
text/typescript
import type {
FlexibleSchema,
IdGenerator,
InferSchema,
ProviderOptions,
} from "@ai-sdk/provider-utils";
import type {
CallSettings,
EmbeddingModel,
GenerateObjectResult,
GenerateTextResult,
LanguageModel,
ModelMessage,
StepResult,
StopCondition,
StreamTextResult,
ToolChoice,
ToolSet,
} from "ai";
import {
embedMany,
generateObject,
generateText,
stepCountIs,
streamObject,
streamText,
} from "ai";
import { assert, omit, pick } from "convex-helpers";
import {
internalActionGeneric,
internalMutationGeneric,
type GenericDataModel,
type PaginationOptions,
type PaginationResult,
type WithoutSystemFields,
} from "convex/server";
import { convexToJson, v, type Value } from "convex/values";
import type { MessageDoc, ThreadDoc } from "../component/schema.js";
import type { threadFieldsSupportingPatch } from "../component/threads.js";
import {
validateVectorDimension,
type VectorDimension,
} from "../component/vector/tables.js";
import {
deserializeMessage,
serializeMessage,
serializeNewMessagesInStep,
serializeObjectResult,
} from "../mapping.js";
import { extractText, isTool } from "../shared.js";
import {
vMessageEmbeddings,
vMessageWithMetadata,
vSafeObjectArgs,
vTextArgs,
type Message,
type MessageStatus,
type MessageWithMetadata,
type ProviderMetadata,
type StreamArgs,
type Usage,
} from "../validators.js";
import { createTool, wrapTools, type ToolCtx } from "./createTool.js";
import {
listMessages,
saveMessages,
type SaveMessageArgs,
type SaveMessagesArgs,
} from "./messages.js";
import {
fetchContextMessages,
getModelName,
getProviderName,
} from "./search.js";
import {
DeltaStreamer,
mergeTransforms,
syncStreams,
type StreamingOptions,
} from "./streaming.js";
import { createThread, getThreadMetadata } from "./threads.js";
import type {
ActionCtx,
AgentComponent,
ContextOptions,
DefaultObjectSchema,
GenerationOutputMetadata,
MaybeCustomCtx,
GenerateObjectArgs,
ObjectMode,
ObjectSchema,
Options,
RawRequestResponseHandler,
RunActionCtx,
RunMutationCtx,
RunQueryCtx,
StorageOptions,
StreamingTextArgs,
StreamObjectArgs,
SyncStreamsReturnValue,
TextArgs,
Thread,
UsageHandler,
UserActionCtx,
} from "./types.js";
import { inlineMessagesFiles } from "./files.js";
import type { DataModel } from "../component/_generated/dataModel.js";
export { stepCountIs } from "ai";
export { vMessageDoc, vThreadDoc } from "../component/schema.js";
export {
deserializeMessage,
serializeDataOrUrl,
serializeMessage,
} from "../mapping.js";
// NOTE: these are also exported via @convex-dev/agent/validators
// a future version may put them all here or move these over there
export {
vAssistantMessage,
vContextOptions,
vMessage,
vPaginationResult,
vProviderMetadata,
vStorageOptions,
vStreamArgs,
vSystemMessage,
vToolMessage,
vUsage,
vUserMessage,
} from "../validators.js";
export type { ToolCtx } from "./createTool.js";
export {
definePlaygroundAPI,
type AgentsFn,
type PlaygroundAPI,
} from "./definePlaygroundAPI.js";
export { getFile, storeFile } from "./files.js";
export {
listMessages,
saveMessage,
saveMessages,
type SaveMessageArgs,
type SaveMessagesArgs,
} from "./messages.js";
export {
fetchContextMessages,
filterOutOrphanedToolMessages,
} from "./search.js";
export { abortStream, listStreams, syncStreams } from "./streaming.js";
export {
createThread,
getThreadMetadata,
updateThreadMetadata,
searchThreadTitles,
} from "./threads.js";
export { createTool, extractText, isTool };
export type {
AgentComponent,
ContextOptions,
MessageDoc,
ProviderMetadata,
RawRequestResponseHandler,
StorageOptions,
StreamArgs,
SyncStreamsReturnValue,
Thread,
ThreadDoc,
Usage,
UsageHandler,
};
export type Config = {
/**
* The LLM model to use for generating / streaming text and objects.
* e.g.
* import { openai } from "@ai-sdk/openai"
* const myAgent = new Agent(components.agent, {
* languageModel: openai.chat("gpt-4o-mini"),
* })
*/
languageModel?: LanguageModel;
/**
* The model to use for text embeddings. Optional.
* If specified, it will be used to generate vector embeddings of messages,
* and you can opt in to vector search for automatic context on
* generateText, etc.
* e.g.
* import { openai } from "@ai-sdk/openai"
* const myAgent = new Agent(components.agent, {
* ...
* textEmbeddingModel: openai.embedding("text-embedding-3-small"),
* })
*/
textEmbeddingModel?: EmbeddingModel<string>;
/**
* Options to determine what messages are included as context in message
* generation. To disable any messages automatically being added, pass:
* { recentMessages: 0 }
*/
contextOptions?: ContextOptions;
/**
* Determines whether messages are automatically stored when passed as
* arguments or generated.
*/
storageOptions?: StorageOptions;
/**
* The usage handler to use for this agent.
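* e.g. a minimal sketch that records token usage; the internal mutation and
* destination table are illustrative, not part of this component:
* usageHandler: async (ctx, { userId, model, provider, usage }) => {
*   await ctx.runMutation(internal.usage.record, {
*     userId, model, provider, totalTokens: usage.totalTokens,
*   });
* },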
*/
usageHandler?: UsageHandler;
/**
* Called for each LLM request/response, so you can do things like
* log the raw request body or response headers to a table or to your logs.
*/
rawRequestResponseHandler?: RawRequestResponseHandler;
/**
* Default provider options to pass for the LLM calls.
* This can be overridden at each generate/stream callsite on a per-field
* basis. To clear a default setting, you'll need to pass `undefined`.
*/
providerOptions?: ProviderOptions;
/**
* The default settings to use for the LLM calls.
* This can be overridden at each generate/stream callsite on a per-field
* basis. To clear a default setting, you'll need to pass `undefined`.
*/
callSettings?: CallSettings;
};
export class Agent<
/**
* You can require that all `ctx` args to generateText & streamText
* have a certain shape by passing a type here.
* e.g.
* ```ts
* const myAgent = new Agent<{ orgId: string }>(...);
* ```
* This is useful if you want to share that type in `createTool`
* e.g.
* ```ts
* type MyCtx = ToolCtx & { orgId: string };
* const myTool = createTool({
* args: z.object({...}),
* description: "...",
* handler: async (ctx: MyCtx, args) => {
* // use ctx.orgId
* },
* });
* ```
*/
CustomCtx extends object = object,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
AgentTools extends ToolSet = any,
> {
constructor(
public component: AgentComponent,
public options: Config & {
/**
* The name of the agent. It will be attributed to each message
* created by this agent.
*/
name: string;
/**
* The LLM model to use for generating / streaming text and objects.
* e.g.
* import { openai } from "@ai-sdk/openai"
* const myAgent = new Agent(components.agent, {
* languageModel: openai.chat("gpt-4o-mini"),
* })
*/
languageModel: LanguageModel;
/**
* The default system prompt to put in each request.
* Override per-prompt by passing the "system" parameter.
*/
instructions?: string;
/**
* Tools that the agent can call out to and get responses from.
* They can be AI SDK tools (import {tool} from "ai")
* or tools that have Convex context
* (import { createTool } from "@convex-dev/agent")
*/
tools?: AgentTools;
/**
* When generating or streaming text with tools available, this
* determines when to stop. Defaults to stepCountIs(1).
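* e.g. `stopWhen: stepCountIs(5)` allows up to five sequential LLM steps
* (tool-call rounds) per generation.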
*/
stopWhen?: StopCondition<AgentTools> | Array<StopCondition<AgentTools>>;
/**
* @deprecated Use `languageModel` instead.
*/
chat?: LanguageModel;
},
) {}
/**
* Start a new thread with the agent. This will have a fresh history, though if
* you pass in a userId you can have it search across other threads for relevant
* messages as context for the LLM calls.
* @param ctx The context of the Convex function. From an action, you also get
* back a `thread` object to chat with the agent. From a mutation, you can
* start a thread and save the threadId to pass to continueThread later.
* @param args The thread metadata.
* @returns The threadId of the new thread and the thread object.
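* e.g. a minimal sketch from an action; `supportAgent` and the function name
* are illustrative:
* ```ts
* export const startConversation = action({
*   args: { prompt: v.string() },
*   handler: async (ctx, { prompt }) => {
*     const { threadId, thread } = await supportAgent.createThread(ctx, {});
*     const result = await thread.generateText({ prompt });
*     return { threadId, text: result.text };
*   },
* });
* ```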
*/
async createThread(
ctx: RunActionCtx & CustomCtx,
args?: {
/**
* The userId to associate with the thread. If not provided, the thread will be
* anonymous.
*/
userId?: string | null;
/**
* The title of the thread. Not currently used for anything.
*/
title?: string;
/**
* The summary of the thread. Not currently used for anything.
*/
summary?: string;
},
): Promise<{ threadId: string; thread: Thread<AgentTools> }>;
/**
* Start a new thread with the agent. This will have a fresh history, though if
* you pass in a userId you can have it search across other threads for relevant
* messages as context for the LLM calls.
* @param ctx The context of the Convex function. From a mutation, you can
* start a thread and save the threadId to pass to continueThread later.
* @param args The thread metadata.
* @returns The threadId of the new thread.
*/
async createThread(
ctx: RunMutationCtx,
args?: {
/**
* The userId to associate with the thread. If not provided, the thread will be
* anonymous.
*/
userId?: string | null;
/**
* The title of the thread. Not currently used for anything.
*/
title?: string;
/**
* The summary of the thread. Not currently used for anything.
*/
summary?: string;
},
): Promise<{ threadId: string }>;
async createThread(
ctx: (ActionCtx & CustomCtx) | RunMutationCtx,
args?: { userId: string | null; title?: string; summary?: string },
): Promise<{ threadId: string; thread?: Thread<AgentTools> }> {
const threadId = await createThread(ctx, this.component, args);
if (!("runAction" in ctx) || "workflowId" in ctx) {
return { threadId };
}
const { thread } = await this.continueThread(ctx, {
threadId,
userId: args?.userId,
});
return { threadId, thread };
}
/**
* Continues a thread using this agent. Note: threads can be continued
* by different agents. This is a convenience around calling the various
* generate and stream functions with explicit userId and threadId parameters.
* @param ctx The ctx object passed to the action handler
* @param { threadId, userId }: the thread and user to associate the messages with.
* @returns Functions bound to the userId and threadId on a `{thread}` object.
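* e.g. a rough sketch; the action and agent names are illustrative:
* ```ts
* export const reply = action({
*   args: { threadId: v.string(), prompt: v.string() },
*   handler: async (ctx, { threadId, prompt }) => {
*     const { thread } = await supportAgent.continueThread(ctx, { threadId });
*     const result = await thread.generateText({ prompt });
*     return result.text;
*   },
* });
* ```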
*/
async continueThread(
ctx: ActionCtx & CustomCtx,
args: {
/**
* The associated thread created by {@link createThread}
*/
threadId: string;
/**
* If supplied, the userId can be used to search across other threads for
* relevant messages from the same user as context for the LLM calls.
*/
userId?: string | null;
},
): Promise<{ thread: Thread<AgentTools> }> {
return {
thread: {
threadId: args.threadId,
getMetadata: this.getThreadMetadata.bind(this, ctx, {
threadId: args.threadId,
}),
updateMetadata: (patch: Partial<WithoutSystemFields<ThreadDoc>>) =>
ctx.runMutation(this.component.threads.updateThread, {
threadId: args.threadId,
patch,
}),
generateText: this.generateText.bind(this, ctx, args),
streamText: this.streamText.bind(this, ctx, args),
generateObject: this.generateObject.bind(this, ctx, args),
streamObject: this.streamObject.bind(this, ctx, args),
} as Thread<AgentTools>,
};
}
async start<TOOLS extends ToolSet | undefined, T>(
ctx: ActionCtx & CustomCtx,
/**
* These are the arguments you'll pass to the LLM call such as
* `generateText` or `streamText`. This function will look up the context
* and provide functions to save the steps, abort the generation, and more.
* The type of the returned arguments is inferred from the type of the
* arguments you pass here.
*/
args: T & {
/**
* If provided, this message will be used as the "prompt" for the LLM call,
* instead of the prompt or messages.
* This is useful if you want to first save a user message, then use it as
* the prompt for the LLM call in another call.
*/
promptMessageId?: string;
/**
* The model to use for the LLM calls. This will override the model specified
* in the Agent constructor.
*/
model?: LanguageModel;
/**
* The tools to use for the tool calls. This will override tools specified
* in the Agent constructor or createThread / continueThread.
*/
tools?: TOOLS;
/**
* The prompt for the LLM call, which becomes the last message(s) in the
* context. If it's a string, it is sent as a user message.
*/
prompt?: string | (ModelMessage | Message)[];
/**
* If provided alongside prompt, the ordering will be:
* 1. system prompt
* 2. search context
* 3. recent messages
* 4. these messages
* 5. prompt messages, including those already on the same `order` as
* the promptMessageId message, if provided.
*/
messages?: (ModelMessage | Message)[];
/**
* This will be the first message in the context, and overrides the
* agent's instructions.
*/
system?: string;
/**
* The abort signal to be passed to the LLM call. If triggered, it will
* mark the pending message as failed. If the generation is asynchronously
* aborted, it will trigger this signal when detected.
*/
abortSignal?: AbortSignal;
// We optimistically override the generateId function to use the pending
// message id.
_internal?: { generateId?: IdGenerator };
},
options?: Options & { userId?: string | null; threadId?: string },
): Promise<{
args: T & {
system?: string;
model: LanguageModel;
messages: ModelMessage[];
// abortSignal?: AbortSignal;
tools?: TOOLS extends undefined ? AgentTools : TOOLS;
} & CallSettings;
order: number;
stepOrder: number;
userId: string | undefined;
promptMessageId: string | undefined;
updateModel: (model: LanguageModel | undefined) => void;
save: <TOOLS extends ToolSet>(
toSave:
| { step: StepResult<TOOLS> }
| { object: GenerateObjectResult<unknown> },
createPendingMessage?: boolean,
) => Promise<void>;
fail: (reason: string) => Promise<void>;
getSavedMessages: () => MessageDoc[];
}> {
const { threadId, ...opts } = { ...this.options, ...options };
const context = await this._saveMessagesAndFetchContext(ctx, args, {
userId: options?.userId,
threadId: options?.threadId,
...opts,
});
let pendingMessageId = context.pendingMessageId;
// TODO: extract pending message if one exists
const { args: aiArgs, promptMessageId, order, stepOrder, userId } = context;
const messages = context.savedMessages ?? [];
if (pendingMessageId) {
if (!aiArgs._internal?.generateId) {
aiArgs._internal = {
...aiArgs._internal,
generateId: () => pendingMessageId ?? crypto.randomUUID(),
};
}
}
const toolCtx = {
...(ctx as UserActionCtx & CustomCtx),
userId,
threadId,
promptMessageId,
agent: this,
} satisfies ToolCtx;
type Tools = TOOLS extends undefined ? AgentTools : TOOLS;
const tools = wrapTools(toolCtx, args.tools ?? this.options.tools) as Tools;
const saveOutput = opts.storageOptions?.saveMessages !== "none";
const fail = async (reason: string) => {
if (threadId && promptMessageId) {
console.error("RollbackMessage", promptMessageId, reason);
}
if (pendingMessageId) {
await ctx.runMutation(this.component.messages.finalizeMessage, {
messageId: pendingMessageId,
result: { status: "failed", error: reason },
});
}
};
let activeModel = aiArgs.model;
if (aiArgs.abortSignal) {
const abortSignal = aiArgs.abortSignal;
aiArgs.abortSignal.addEventListener(
"abort",
async () => {
await fail(abortSignal.reason ?? "Aborted");
},
{ once: true },
);
}
return {
args: {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
stopWhen: (args as any).stopWhen ?? this.options.stopWhen,
...aiArgs,
tools,
// abortSignal: abortController.signal,
},
order: order ?? 0,
stepOrder: stepOrder ?? 0,
userId,
promptMessageId,
getSavedMessages: () => messages,
updateModel: (model: LanguageModel | undefined) => {
if (model) {
activeModel = model;
}
},
fail,
save: async <TOOLS extends ToolSet>(
toSave:
| { step: StepResult<TOOLS> }
| { object: GenerateObjectResult<unknown> },
createPendingMessage?: boolean,
) => {
if (threadId && promptMessageId && saveOutput) {
const metadata = {
// TODO: get the up-to-date model if the user switches models mid-generation
model: getModelName(activeModel),
provider: getProviderName(activeModel),
};
const serialized =
"object" in toSave
? await serializeObjectResult(
ctx,
this.component,
toSave.object,
metadata,
)
: await serializeNewMessagesInStep(
ctx,
this.component,
toSave.step,
metadata,
);
const embeddings = await this.generateEmbeddings(
ctx,
{ userId, threadId },
serialized.messages.map((m) => m.message),
);
if (createPendingMessage) {
serialized.messages.push({
message: { role: "assistant", content: [] },
status: "pending",
});
embeddings?.vectors.push(null);
}
const saved = await ctx.runMutation(
this.component.messages.addMessages,
{
userId,
threadId,
agentName: this.options.name,
promptMessageId,
pendingMessageId,
messages: serialized.messages,
embeddings,
failPendingSteps: false,
},
);
const lastMessage = saved.messages.at(-1)!;
if (createPendingMessage) {
if (lastMessage.status === "failed") {
pendingMessageId = undefined;
messages.push(...saved.messages);
await fail(
lastMessage.error ??
"Aborting - the pending message was marked as failed",
);
} else {
pendingMessageId = lastMessage._id;
messages.push(...saved.messages.slice(0, -1));
}
} else {
pendingMessageId = undefined;
messages.push(...saved.messages);
}
}
const output = "object" in toSave ? toSave.object : toSave.step;
if (this.options.rawRequestResponseHandler) {
await this.options.rawRequestResponseHandler(ctx, {
userId,
threadId,
agentName: this.options.name,
request: output.request,
response: output.response,
});
}
if (opts.usageHandler && output.usage) {
await opts.usageHandler(ctx, {
userId,
threadId,
agentName: this.options.name,
model: getModelName(activeModel),
provider: getProviderName(activeModel),
usage: output.usage,
providerMetadata: output.providerMetadata,
});
}
},
};
}
/**
* This behaves like {@link generateText} from the "ai" package except that
* it adds context based on the userId and threadId and saves the input and
* resulting messages to the thread, if specified.
* Use {@link continueThread} to get a version of this function already scoped
* to a thread (and optionally userId).
* @param ctx The context passed from the action function calling this.
* @param { userId, threadId }: The user and thread to associate the message with
* @param generateTextArgs The arguments to the generateText function, along with extra controls
* for the {@link ContextOptions} and {@link StorageOptions}.
* @returns The result of the generateText function.
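* e.g. a minimal sketch with explicit ids (`supportAgent` is illustrative):
* ```ts
* const result = await supportAgent.generateText(
*   ctx,
*   { userId, threadId },
*   { prompt: "Summarize the conversation so far." },
* );
* console.log(result.text);
* ```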
*/
async generateText<
TOOLS extends ToolSet | undefined = undefined,
OUTPUT = never,
OUTPUT_PARTIAL = never,
>(
ctx: ActionCtx & CustomCtx,
threadOpts: { userId?: string | null; threadId?: string },
generateTextArgs: TextArgs<AgentTools, TOOLS, OUTPUT, OUTPUT_PARTIAL>,
options?: Options,
): Promise<
GenerateTextResult<TOOLS extends undefined ? AgentTools : TOOLS, OUTPUT> &
GenerationOutputMetadata
> {
const { args, promptMessageId, order, ...call } = await this.start(
ctx,
generateTextArgs,
{ ...threadOpts, ...options },
);
type Tools = TOOLS extends undefined ? AgentTools : TOOLS;
const steps: StepResult<Tools>[] = [];
try {
const result = (await generateText<Tools, OUTPUT, OUTPUT_PARTIAL>({
...args,
prepareStep: async (options) => {
const result = await generateTextArgs.prepareStep?.(options);
call.updateModel(result?.model ?? options.model);
return result;
},
onStepFinish: async (step) => {
steps.push(step);
await call.save({ step }, await willContinue(steps, args.stopWhen));
return generateTextArgs.onStepFinish?.(step);
},
})) as GenerateTextResult<Tools, OUTPUT>;
const metadata: GenerationOutputMetadata = {
promptMessageId,
order,
savedMessages: call.getSavedMessages(),
messageId: promptMessageId,
};
return Object.assign(result, metadata);
} catch (error) {
await call.fail(errorToString(error));
throw error;
}
}
/**
* This behaves like {@link streamText} from the "ai" package except that
* it adds context based on the userId and threadId and saves the input and
* resulting messages to the thread, if specified.
* Use {@link continueThread} to get a version of this function already scoped
* to a thread (and optionally userId).
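* e.g. a rough sketch that also persists deltas so other clients can
* subscribe to them (`supportAgent` is illustrative):
* ```ts
* const result = await supportAgent.streamText(
*   ctx,
*   { userId, threadId },
*   { prompt },
*   { saveStreamDeltas: true },
* );
* for await (const chunk of result.textStream) {
*   // use each chunk directly, e.g. for an HTTP streaming response
* }
* ```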
*/
async streamText<
TOOLS extends ToolSet | undefined = undefined,
OUTPUT = never,
PARTIAL_OUTPUT = never,
>(
ctx: ActionCtx & CustomCtx,
threadOpts: { userId?: string | null; threadId?: string },
/**
* The arguments to the streamText function, similar to the ai `streamText` function.
*/
streamTextArgs: StreamingTextArgs<
AgentTools,
TOOLS,
OUTPUT,
PARTIAL_OUTPUT
>,
/**
* The {@link ContextOptions} and {@link StorageOptions}
* options to use for fetching contextual messages and saving input/output messages.
*/
options?: Options & {
/**
* Whether to save incremental data (deltas) from streaming responses.
* Defaults to false.
* If false, it will not save any deltas to the database.
* If true, it will save deltas with {@link DEFAULT_STREAMING_OPTIONS}.
*
* Regardless of this option, when streaming you are able to use this
* `streamText` function as you would with the "ai" package's version:
* iterating over the text, streaming it over HTTP, etc.
*/
saveStreamDeltas?: boolean | StreamingOptions;
},
): Promise<
StreamTextResult<
TOOLS extends undefined ? AgentTools : TOOLS,
PARTIAL_OUTPUT
> &
GenerationOutputMetadata
> {
const { threadId } = threadOpts;
const { args, userId, order, stepOrder, promptMessageId, ...call } =
await this.start(ctx, streamTextArgs, { ...threadOpts, ...options });
type Tools = TOOLS extends undefined ? AgentTools : TOOLS;
const steps: StepResult<Tools>[] = [];
const opts = { ...this.options, ...options };
const streamer =
threadId && opts.saveStreamDeltas
? new DeltaStreamer(this.component, ctx, opts.saveStreamDeltas, {
threadId,
userId,
agentName: this.options.name,
model: getModelName(args.model),
provider: getProviderName(args.model),
providerOptions: args.providerOptions,
order,
stepOrder,
abortSignal: args.abortSignal,
})
: undefined;
const result = streamText({
...args,
abortSignal: streamer?.abortController.signal ?? args.abortSignal,
// TODO: this is probably why reasoning isn't streaming
experimental_transform: mergeTransforms(
options?.saveStreamDeltas,
streamTextArgs.experimental_transform,
),
onChunk: async (event) => {
await streamer?.addParts([event.chunk]);
// console.log("onChunk", chunk);
return streamTextArgs.onChunk?.(event);
},
onError: async (error) => {
console.error("onError", error);
await call.fail(errorToString(error.error));
await streamer?.fail(errorToString(error.error));
return streamTextArgs.onError?.(error);
},
// onFinish: async (event) => {
// return streamTextArgs.onFinish?.(event);
// },
prepareStep: async (options) => {
const result = await streamTextArgs.prepareStep?.(options);
if (result) {
const model = result.model ?? options.model;
call.updateModel(model);
return result;
}
return undefined;
},
onStepFinish: async (step) => {
steps.push(step);
const createPendingMessage = await willContinue(steps, args.stopWhen);
await call.save({ step }, createPendingMessage);
if (!createPendingMessage) {
await streamer?.finish();
}
return args.onStepFinish?.(step);
},
}) as StreamTextResult<
TOOLS extends undefined ? AgentTools : TOOLS,
PARTIAL_OUTPUT
>;
const metadata: GenerationOutputMetadata = {
promptMessageId,
order,
savedMessages: call.getSavedMessages(),
messageId: promptMessageId,
};
if (
(typeof options?.saveStreamDeltas === "object" &&
!options.saveStreamDeltas.returnImmediately) ||
options?.saveStreamDeltas === true
) {
await result.consumeStream();
}
return Object.assign(result, metadata);
}
/**
* This behaves like {@link generateObject} from the "ai" package except that
* it adds context based on the userId and threadId and saves the input and
* resulting messages to the thread, if specified.
* Use {@link continueThread} to get a version of this function already scoped
* to a thread (and optionally userId).
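* e.g. a minimal sketch with a zod schema (schema contents illustrative):
* ```ts
* const result = await supportAgent.generateObject(
*   ctx,
*   { userId, threadId },
*   {
*     prompt: "Extract the customer's name and sentiment.",
*     schema: z.object({ name: z.string(), sentiment: z.string() }),
*   },
* );
* console.log(result.object);
* ```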
*/
async generateObject<
SCHEMA extends ObjectSchema = DefaultObjectSchema,
OUTPUT extends ObjectMode = InferSchema<SCHEMA> extends string
? "enum"
: "object",
RESULT = OUTPUT extends "array"
? Array<InferSchema<SCHEMA>>
: InferSchema<SCHEMA>,
>(
ctx: ActionCtx & CustomCtx,
threadOpts: { userId?: string | null; threadId?: string },
/**
* The arguments to the generateObject function, similar to the ai `generateObject` function.
*/
generateObjectArgs: GenerateObjectArgs<SCHEMA, OUTPUT, RESULT>,
/**
* The {@link ContextOptions} and {@link StorageOptions}
* options to use for fetching contextual messages and saving input/output messages.
*/
options?: Options,
): Promise<GenerateObjectResult<RESULT> & GenerationOutputMetadata> {
const { args, promptMessageId, order, fail, save, getSavedMessages } =
await this.start(ctx, generateObjectArgs, { ...threadOpts, ...options });
try {
const result = (await generateObject(
args,
)) as GenerateObjectResult<RESULT>;
await save({ object: result });
const metadata: GenerationOutputMetadata = {
promptMessageId,
order,
savedMessages: getSavedMessages(),
messageId: promptMessageId,
};
return Object.assign(result, metadata);
} catch (error) {
await fail(errorToString(error));
throw error;
}
}
/**
* This behaves like `streamObject` from the "ai" package except that
* it adds context based on the userId and threadId and saves the input and
* resulting messages to the thread, if specified.
* Use {@link continueThread} to get a version of this function already scoped
* to a thread (and optionally userId).
*/
async streamObject<
SCHEMA extends ObjectSchema = DefaultObjectSchema,
OUTPUT extends ObjectMode = InferSchema<SCHEMA> extends string
? "enum"
: "object",
RESULT = OUTPUT extends "array"
? Array<InferSchema<SCHEMA>>
: InferSchema<SCHEMA>,
>(
ctx: ActionCtx & CustomCtx,
threadOpts: { userId?: string | null; threadId?: string },
/**
* The arguments to the streamObject function, similar to the ai `streamObject` function.
*/
streamObjectArgs: StreamObjectArgs<SCHEMA, OUTPUT, RESULT> & {
/**
* If provided, this message will be used as the "prompt" for the LLM call,
* instead of the prompt or messages.
* This is useful if you want to first save a user message, then use it as
* the prompt for the LLM call in another call.
*/
promptMessageId?: string;
/**
* The model to use for the LLM calls. This will override the model specified
* in the Agent constructor.
*/
model?: LanguageModel;
/**
* The tools to use for the tool calls. This will override tools specified
* in the Agent constructor or createThread / continueThread.
*/
},
/**
* The {@link ContextOptions} and {@link StorageOptions}
* options to use for fetching contextual messages and saving input/output messages.
*/
options?: Options,
): Promise<
ReturnType<typeof streamObject<SCHEMA, OUTPUT, RESULT>> &
GenerationOutputMetadata
> {
const { args, promptMessageId, order, fail, save, getSavedMessages } =
await this.start(ctx, streamObjectArgs, { ...threadOpts, ...options });
const stream = streamObject<SCHEMA, OUTPUT, RESULT>({
// eslint-disable-next-line @typescript-eslint/no-explicit-any
...(args as any),
onError: async (error) => {
console.error(" streamObject onError", error);
// TODO: content that we have so far
// content: stream.fullStream.
await fail(errorToString(error.error));
return args.onError?.(error);
},
onFinish: async (result) => {
await save({
object: {
object: result.object,
finishReason: result.error ? "error" : "stop",
usage: result.usage,
warnings: result.warnings,
request: await stream.request,
response: result.response,
providerMetadata: result.providerMetadata,
toJsonResponse: stream.toTextStreamResponse,
},
});
return args.onFinish?.(result);
},
});
const metadata: GenerationOutputMetadata = {
promptMessageId,
order,
savedMessages: getSavedMessages(),
messageId: promptMessageId,
};
return Object.assign(stream, metadata);
}
/**
* Save a message to the thread.
* @param ctx A ctx object from a mutation or action.
* @param args The message and what to associate it with (user / thread)
* You can pass extra metadata alongside the message, e.g. associated fileIds.
* @returns The messageId of the saved message.
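* e.g. a sketch of saving a user message in a mutation, then responding
* asynchronously (the scheduled action name is illustrative):
* ```ts
* const { messageId } = await supportAgent.saveMessage(ctx, {
*   threadId,
*   prompt,
*   skipEmbeddings: true,
* });
* await ctx.scheduler.runAfter(0, internal.chat.respond, {
*   threadId,
*   promptMessageId: messageId,
* });
* ```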
*/
async saveMessage(
ctx: RunMutationCtx,
args: SaveMessageArgs & {
/**
* If true, it will not generate embeddings for the message.
* Useful if you're saving messages in a mutation where you can't run `fetch`.
* You can generate them asynchronously by using the scheduler to run an
* action later that calls `agent.generateAndSaveEmbeddings`.
*/
skipEmbeddings?: boolean;
},
) {
const { messages } = await this.saveMessages(ctx, {
threadId: args.threadId,
userId: args.userId,
embeddings: args.embedding
? { model: args.embedding.model, vectors: [args.embedding.vector] }
: undefined,
messages:
args.prompt !== undefined
? [{ role: "user", content: args.prompt }]
: [args.message],
metadata: args.metadata ? [args.metadata] : undefined,
skipEmbeddings: args.skipEmbeddings,
pendingMessageId: args.pendingMessageId,
});
const message = messages.at(-1)!;
return { messageId: message._id, message };
}
/**
* Explicitly save messages associated with the thread (& user if provided)
* If you have an embedding model set, it will also generate embeddings for
* the messages.
* @param ctx The ctx parameter to a mutation or action.
* @param args The messages and context to save
* @returns The saved messages as MessageDocs.
*/
async saveMessages(
ctx: RunMutationCtx | RunActionCtx,
args: SaveMessagesArgs & {
/**
* Skip generating embeddings for the messages. Useful if you're
* saving messages in a mutation where you can't run `fetch`.
* You can generate them asynchronously by using the scheduler to run an
* action later that calls `agent.generateAndSaveEmbeddings`.
*/
skipEmbeddings?: boolean;
},
): Promise<{ messages: MessageDoc[] }> {
let embeddings: { vectors: (number[] | null)[]; model: string } | undefined;
const { skipEmbeddings, ...rest } = args;
if (args.embeddings) {
embeddings = args.embeddings;
} else if (!skipEmbeddings && this.options.textEmbeddingModel) {
if (!("runAction" in ctx)) {
console.warn(
"You're trying to save messages and generate embeddings, but you're in a mutation. " +
"Pass `skipEmbeddings: true` to skip generating embeddings in the mutation and skip this warning. " +
"They will be generated lazily when you generate or stream text / objects. " +
"You can explicitly generate them asynchronously by using the scheduler to run an action later that calls `agent.generateAndSaveEmbeddings`.",
);
} else if ("workflowId" in ctx) {
console.warn(
"You're trying to save messages and generate embeddings, but you're in a workflow. " +
"Pass `skipEmbeddings: true` to skip generating embeddings in the workflow and skip this warning. " +
"They will be generated lazily when you generate or stream text / objects. " +
"You can explicitly generate them asynchronously by using the scheduler to run an action later that calls `agent.generateAndSaveEmbeddings`.",
);
} else {
embeddings = await this.generateEmbeddings(
ctx,
{ userId: args.userId ?? undefined, threadId: args.threadId },
args.messages,
);
}
}
return saveMessages(ctx, this.component, {
...rest,
agentName: this.options.name,
embeddings,
});
}
/**
* List messages from a thread.
* @param ctx A ctx object from a query, mutation, or action.
* @param args.threadId The thread to list messages from.
* @param args.paginationOpts Pagination options (e.g. via usePaginatedQuery).
* @param args.excludeToolMessages Whether to exclude tool messages.
* False by default.
* @param args.statuses What statuses to include. All by default.
* @returns The MessageDocs in a format compatible with usePaginatedQuery.
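* e.g. a sketch of exposing it as a paginated query (names illustrative):
* ```ts
* export const messages = query({
*   args: { threadId: v.string(), paginationOpts: paginationOptsValidator },
*   handler: async (ctx, args) => supportAgent.listMessages(ctx, args),
* });
* ```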
*/
async listMessages(
ctx: RunQueryCtx,
args: {
threadId: string;
paginationOpts: PaginationOptions;
excludeToolMessages?: boolean;
statuses?: MessageStatus[];
},
): Promise<PaginationResult<MessageDoc>> {
return listMessages(ctx, this.component, args);
}
/**
* A function that handles fetching stream deltas, used with the React hooks
* `useThreadMessages` or `useStreamingThreadMessages`.
* @param ctx A ctx object from a query, mutation, or action.
* @param args.threadId The thread to sync streams for.
* @param args.streamArgs The stream arguments with per-stream cursors.
* @returns The deltas for each stream from their existing cursor.
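* e.g. a rough sketch of the usual pattern, combined with listMessages in a
* single query (names illustrative):
* ```ts
* export const listThreadMessages = query({
*   args: {
*     threadId: v.string(),
*     paginationOpts: paginationOptsValidator,
*     streamArgs: vStreamArgs,
*   },
*   handler: async (ctx, args) => {
*     const streams = await supportAgent.syncStreams(ctx, args);
*     const paginated = await supportAgent.listMessages(ctx, args);
*     return { ...paginated, streams };
*   },
* });
* ```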
*/
async syncStreams(
ctx: RunQueryCtx,
args: {
threadId: string;
streamArgs: StreamArgs | undefined;
// By default, only streaming messages are included.
includeStatuses?: ("streaming" | "finished" | "aborted")[];
},
): Promise<SyncStreamsReturnValue | undefined> {
return syncStreams(ctx, this.component, args);
}
/**
* Fetch the context messages for a thread.
* @param ctx Either a query, mutation, or action ctx.
* If it is not an action context, you can't do text or
* vector search.
* @param args The associated thread, user, and messages to fetch context for.
* @returns The context messages as MessageDocs.
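* e.g. a rough sketch of fetching context manually (option values illustrative):
* ```ts
* const context = await supportAgent.fetchContextMessages(ctx, {
*   userId,
*   threadId,
*   messages: [{ role: "user", content: prompt }],
*   contextOptions: { recentMessages: 10 },
* });
* ```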
*/
async fetchContextMessages(
ctx: RunQueryCtx | RunActionCtx,
args: {
userId: string | undefined;
threadId: string | undefined;
messages: (ModelMessage | Message)[];
/**
* If provided, it will search for messages up to and including this message.
* Note: if this is far in the past, text and vector search results may be more
* limited, as it's post-filtering the results.
*/
upToAndIncludingMessageId?: string;
contextOptions: ContextOptions | undefined;
},
): Promise<MessageDoc[]> {
assert(args.userId || args.threadId, "Specify userId or threadId");
const contextOptions = {
...this.options.contextOptions,
...args.contextOptions,
};
return fetchContextMessages(ctx, this.component, {
...args,
contextOptions,
getEmbedding: async (text) => {
assert("runAction" in ctx);
assert(
this.options.textEmbeddingModel,
"A textEmbeddingModel is required to be set on the Agent that you're doing vector search with",
);
return {
embedding: (
await this.doEmbed(ctx, {
userId: args.userId,
threadId: args.threadId,
values: [text],
})
).embeddings[0],
textEmbeddingModel: this.options.textEmbeddingModel,
};
},
});
}
/**
* Get the metadata for a thread.
* @param ctx A ctx object from a query, mutation, or action.
* @param args.threadId The thread to get the metadata for.
* @returns The metadata for the thread.
*/
async getThreadMetadata(
ctx: RunQueryCtx,
args: { threadId: string },
): Promise<ThreadDoc> {
return getThreadMetadata(ctx, this.component, args);
}
/**
* Update the metadata for a thread.
* @param ctx A ctx object from a mutation or action.
* @param args.threadId The thread to update the metadata for.
* @param args.patch The patch to apply to the thread.
* @returns The updated thread metadata.
*/
async updateThreadMetadata(
ctx: RunMutationCtx,
args: {
threadId: string;
patch: Partial<
Pick<ThreadDoc, (typeof threadFieldsSupportingPatch)[number]>
>;
},
): Promise<ThreadDoc> {
const thread = await ctx.runMutation(
this.component.threads.updateThread,
args,
);
return thread;
}
/**
* Get the embeddings for a set of messages.
* @param messages The messages to get the embeddings for.
* @returns The embeddings for the messages.
*/
async generateEmbeddings(
ctx: RunActionCtx,
{
userId,
threadId,
}: { userId: string | undefined; threadId: string | undefined },
messages: (ModelMessage | Message)[],
) {
if (!this.options.textEmbeddingModel) {
return undefined;
}
let embeddings:
| {
vectors: (number[] | null)[];
dimension: VectorDimension;
model: string;
}
| undefined;
const messageTexts = messages.map((m) => !isTool(m) && extractText(m));
// Find the indexes of the messages that have text.
const textIndexes = messageTexts
.map((t, i) => (t ? i : undefined))
.filter((i) => i !== undefined);
if (textIndexes.length === 0) {
return undefined;
}
const values = messageTexts.filter((t): t is string => !!t);
// Then embed those messages.
const textEmbeddings = await this.doEmbed(ctx, {
userId,
threadId,
values,
});
// Then assemble the embeddings into a single array with nulls for the messages without text.
const embeddingsOrNull = Array(messages.length).fill(null);
textIndexes.forEach((i, j) => {
embeddingsOrNull[i] = textEmbeddings.embeddings[j];
});
if (textEmbeddings.embeddings.length > 0) {
const dimension = textEmbeddings.embeddings[0].length;
validateVectorDimension(dimension);
const model = getModelName(this.options.textEmbeddingModel);
embeddings = { vectors: embeddingsOrNull, dimension, model };
}
return embeddings;
}
/**
* Generate embeddings for a set of messages, and save them to the database.
* It will not generate or save embeddings for messages that already have an
* embedding.
* @param ctx The ctx parameter to an action.
* @param args The messageIds to generate embeddings for.
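* e.g. a sketch of generating embeddings from an internal action after saving
* messages with `skipEmbeddings: true` in a mutation (names illustrative):
* ```ts
* export const embedMessages = internalAction({
*   args: { messageIds: v.array(v.string()) },
*   handler: async (ctx, { messageIds }) => {
*     await supportAgent.generateAndSaveEmbeddings(ctx, { messageIds });
*   },
* });
* ```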
*/
async generateAndSaveEmbeddings(
ctx: RunActionCtx,
args: { messageIds: string[] },
) {
const messages = (
await ctx.runQuery(this.component.messages.getMessagesByIds, {
messageIds: args.messageIds,
})
).filter((m): m is NonNullable<typeof m> => m !== null);
if (messages.length !== args.messageIds.length) {
throw new Error(
"Some messages were not found: " +
args.messageIds
.filter((id) => !messages.some((m) => m?._id === id))
.join(", "),
);
}
await this._generateAndSaveEmbeddings(ctx, messages);
}
async _generateAndSaveEmbeddings(ctx: RunActionCtx, messages: MessageDoc[]) {
if (messages.some((m) => !m.message)) {
throw new Error(
"Some messages don't have a message: " +
messages
.filter((m) => !m.message)
.map((m) => m._id)
.join(", "),
);
}
const messagesMissingEmbeddings = messages.filter((m) => !m.embeddingId);
if (messagesMissingEmbeddings.length === 0) {
return;
}
const embeddings = await this.generateEmbeddings(
ctx,
{
userId: messagesMissingEmbeddings[0]!.userId,
threadId: messagesMissingEmbeddings[0]!.threadId,
},
messagesMissingEmbeddings.map((m) => deserializeMessage(m!.message!)),
);
if (!embeddings) {
if (!this.options.textEmbeddingModel) {
throw new Error(
"No embeddings were generated for the messages. You must pass a textEmbeddingModel to the agent constructor.",
);
}
throw new Error(
"No embeddings were generated for these messages: " +
messagesMissingEmbeddings.map((m) => m!._id).join(", "),
);
}
await ctx.runMutation(this.component.vector.index.insertBatch, {
vectorDimension: embeddings.dimension,
vectors: messagesMissingEmbeddings
.map((m, i) => ({
messageId: m!._id,
model: embeddings.model,
table: "messages",
userId: m.userId,
threadId: m.threadId,
vector: embeddings.vectors[i],
}))
.filter(
(v): v is Extract<typeof v, { vector: number[] }> =>
v.vector !== null,
),
});
}
/**
* Explicitly save a "step" created by the AI SDK.
* @param ctx The ctx argument to a mutation or action.
* @param args The Step generated by the AI SDK.
*/
async saveStep<TOOLS extends ToolSet>(
ctx: ActionCtx,
args: {
userId?: string;
threadId: string;
/**
* The message this step is in response to.
*/
promptMessageId: string;
/**
* The step to save, possibly including multiple tool calls.
*/
step: StepResult<TOOLS>;
/**
* The model used to generate the step.
* Defaults to the Agent's languageModel.
*/
model?: string;
/**
* The provider of the model used to generate the step.
* Defaults to the provider of the Agent's languageModel.
*/
provider?: string;
},
): Promise<{ messages: MessageDoc[] }> {
const { messages } = await serializeNewMessagesInStep(
ctx,
this.component,
args.step,
{
provider: args.provider ?? getProviderName(this.options.languageModel),
model: args.model ?? getModelName(this.options.languageModel),
},
);
const embeddings = await this.generateEmbeddings(
ctx,
{ userId: args.userId, threadId: args.threadId },
messages.map((m) => m.message),
);
return ctx.runMutation(this.component.messages.addMessages, {
userId: args.userId,
threadId: args.threadId,
agentName: this.options.name,
promptMessageId: args.promptMessageId,
messages,
embeddings,
failPendingSteps: false,
});
}
/**
* Manually save the result of a generateObject call to the thread.
* This happens automatically when using {@link generateObject} or {@link streamObject}
* from the `thread` object created by {@link continueThread} or {@link createThread}.
* @param ctx The context passed from the mutation or action function calling this.
* @param args The arguments to the saveObject function.
*/
async saveObject(
ctx: ActionCtx,
args: {
userId: string | undefined;
threadId: string;
promptMessageId: string;
model: string | undefined;
provider: string | undefined;
result: GenerateObjectResult<unknown>;
metadata?: Omit<MessageWithMetadata, "message">;
},
): Promise<{ messages: MessageDoc[] }> {
const { messages } = await serializeObjectResult(
ctx,
this.component,
args.result,
{
model:
args.model ??
args.metadata?.model ??
getModelName(this.options.languageModel),
provider:
args.provider ??
args.metadata?.provider ??
getProviderName(this.options.languageModel),
},
);
const embeddings = await this.generateEmbeddings(
ctx,
{ userId: args.userId, threadId: args.threadId },
messages.map((m) => m.message),
);
return ctx.runMutation(this.component.messages.addMessages, {
userId: args.userId,
threadId: args.threadId,
promptMessageId: args.promptMessageId,
failPendingSteps: false,
messages,
embeddings,
agentName: this.options.name,
});
}
/**
* Commit or rollback a message that was pending.
* This is done automatically when saving messages by default.
* If creating pending messages, you can call this when the full "transaction" is done.
* @param ctx The ctx argument to your mutation or action.
* @param args Which message to finalize and the result (success or failure)
* to record. Generally this is the prompt message sent into the generateText call.
*/
async finalizeMessage(
ctx: RunMutationCtx,
args: {
messageId: string;
result: { status: "failed"; error: string } | { status: "success" };
},
): Promise<void> {
await ctx.runMutation(this.component.messages.finalizeMessage, {
messageId: args.messageId,
result: args.result,
});
}
/**
* Update a message by its id.
* @param ctx The ctx argument to your mutation or action.
* @param args The message fields to update.
*/
async updateMessage(
ctx: RunMutationCtx,
args: {
/** The id of the message to update. */
messageId: string;
patch: {
/** The message to replace the existing message. */
message: ModelMessage | Message;
/** The status to set on the message. */
status: "success" | "error";
/** The error message to set on the message. */
error?: string;
/**
* These will override the fileIds in the message.
* To remove all existing files, pass an empty array.
* If passing in a new message, pass in the fileIds you explicitly want to keep
* from the previous message, as the new files generated from the new message
* will be added to the list.
* If you pass undefined, it will not change the fileIds unless new
* files are generated from the message. In that case, the new fileIds
* will replace the old fileIds.
*/
fileIds?: string[];
};
},
): Promise<void> {
const { message, fileIds } = await serializeMessage(
ctx,
this.component,
args.patch.message,
);
await ctx.runMutation(this.component.messages.updateMessage, {
messageId: args.messageId,
patch: {
message,
fileIds: args.patch.fileIds
? [...args.patch.fileIds, ...(fileIds ?? [])