@convex-dev/agent
Version:
An agent component for Convex.
417 lines (407 loc) • 11.9 kB
text/typescript
import type { TextStreamPart, ToolSet } from "ai";
import type { MessageDoc } from "../client/index.js";
import type {
Message,
MessageStatus,
StreamDelta,
StreamMessage,
vReasoningPart,
vTextPart,
vToolCallPart,
vToolResultPart,
} from "../validators.js";
import type { Infer } from "convex/values";
/**
 * Rebuilds the client-side view of every listed stream from newly fetched deltas.
 *
 * For each entry in `streamMessages`, the stream's previous state is looked up
 * by `streamId` in `existingStreams`. Its deltas are filtered out of
 * `allDeltas`, and both are handed to `applyDeltasToStreamMessage`.
 *
 * @param threadId - Thread the streamed messages belong to.
 * @param streamMessages - Streams to render, in the order their state is returned.
 * @param existingStreams - Previously computed per-stream state.
 * @param allDeltas - Deltas for any of the streams; grouped here by `streamId`.
 * @returns A tuple containing:
 *   - all streamed messages, sorted by `order` and then `stepOrder`;
 *   - the new per-stream state, one entry per `streamMessages` item;
 *   - whether anything changed, including a previously tracked stream that
 *     is no longer listed.
 */
export function mergeDeltas(
  threadId: string,
  streamMessages: StreamMessage[],
  existingStreams: Array<{
    streamId: string;
    cursor: number;
    messages: MessageDoc[];
  }>,
  allDeltas: StreamDelta[],
): [
  MessageDoc[],
  Array<{ streamId: string; cursor: number; messages: MessageDoc[] }>,
  boolean,
] {
  let changed = false;
  const newStreams = streamMessages.map((streamMessage) => {
    const { streamId } = streamMessage;
    const streamDeltas = allDeltas.filter((d) => d.streamId === streamId);
    const previous = existingStreams.find((s) => s.streamId === streamId);
    const [stream, streamChanged] = applyDeltasToStreamMessage(
      threadId,
      streamMessage,
      previous,
      streamDeltas,
    );
    changed = changed || streamChanged;
    return stream;
  });

  // A stream we were tracking that's no longer listed is no longer active.
  const activeIds = new Set(newStreams.map((s) => s.streamId));
  if (existingStreams.some(({ streamId }) => !activeIds.has(streamId))) {
    changed = true;
  }

  const messages = newStreams
    .flatMap((s) => s.messages)
    .sort((a, b) => a.order - b.order || a.stepOrder - b.stepOrder);
  return [messages, newStreams, changed];
}
/**
 * Part types that can seed or extend a streamed message. This is the same set
 * `createStreamingMessage` accepts. Any other part (e.g. "raw") carries no
 * message content and is skipped.
 */
const MESSAGE_PART_TYPES: ReadonlySet<string> = new Set([
  "text-delta",
  "tool-input-start",
  "tool-input-delta",
  "tool-call",
  "tool-result",
  "reasoning-delta",
  "source",
]);

// exported for testing
/**
 * Applies a stream's new deltas on top of its previously materialized state.
 *
 * Deltas are processed in `start` order:
 * - A delta whose `start` equals the cursor is applied, and the cursor moves
 *   to its `end`.
 * - A delta the cursor has already passed is skipped as a duplicate.
 * - A delta that leaves a gap is skipped with a warning.
 * - A delta that partially overlaps the cursor throws.
 *
 * With no applicable deltas, the stream still counts as changed when its
 * status no longer matches the last message's status.
 *
 * Parts are appended to the last existing message, or to a new one if there
 * is none. A new message starts whenever a part's role (tool vs. assistant)
 * differs from the current message's role. The returned state never shares
 * mutable message objects, content parts, or `sources` arrays with
 * `existing`.
 *
 * @param threadId - Thread the synthesized messages belong to.
 * @param streamMessage - Metadata and status of the stream being rendered.
 * @param existing - Previously materialized state for this stream, if any.
 * @param deltas - Deltas for this stream, in any order. Not mutated.
 * @returns The stream state and whether it changed. When nothing changed,
 *   `existing` itself is returned, or an empty state if there was none.
 * @throws If a delta partially overlaps the cursor, or a `tool-input-delta`
 *   does not follow a tool call with string args.
 */
export function applyDeltasToStreamMessage(
  threadId: string,
  streamMessage: StreamMessage,
  existing:
    | { streamId: string; cursor: number; messages: MessageDoc[] }
    | undefined,
  deltas: StreamDelta[],
): [{ streamId: string; cursor: number; messages: MessageDoc[] }, boolean] {
  let changed = false;
  let cursor = existing?.cursor ?? 0;
  let parts: TextStreamPart<ToolSet>[] = [];
  // Sort a copy so the caller's array is not reordered.
  const orderedDeltas = [...deltas].sort((a, b) => a.start - b.start);
  for (const delta of orderedDeltas) {
    if (delta.parts.length === 0) {
      console.warn(`Got delta with no parts: ${JSON.stringify(delta)}`);
      continue;
    }
    if (cursor !== delta.start) {
      if (cursor >= delta.end) {
        // Entirely covered by what we've already applied.
        console.debug(
          `Got duplicate delta for stream ${delta.streamId} at ${delta.start}`,
        );
        continue;
      } else if (cursor < delta.start) {
        // Deltas are missing in between; this one can't be applied contiguously.
        console.warn(
          `Got delta for stream ${delta.streamId} that has a gap ${cursor} -> ${delta.start}`,
        );
        continue;
      } else {
        // delta.start < cursor < delta.end: a partial overlap we can't split.
        throw new Error(
          `Got unexpected delta for stream ${delta.streamId}: delta: ${delta.start} -> ${delta.end} existing cursor: ${cursor}`,
        );
      }
    }
    changed = true;
    cursor = delta.end;
    parts.push(...delta.parts);
  }

  // No new content, but a status transition (e.g. finished/aborted) still
  // needs to be reflected on the last message.
  if (existing && existing.messages.length > 0 && !changed) {
    const lastMessage = existing.messages.at(-1)!;
    if (statusFromStreamStatus(streamMessage.status) !== lastMessage.status) {
      changed = true;
    }
  }
  if (!changed) {
    return [
      existing ?? { streamId: streamMessage.streamId, cursor, messages: [] },
      false,
    ];
  }

  const existingMessages = existing?.messages ?? [];
  let currentMessage: MessageDoc;
  if (existingMessages.length > 0) {
    // Replace the last message with a copy we're free to mutate.
    // cloneMessageAndContent copies the content array. The last content part
    // is copied too, because text, reasoning, and tool-args deltas are
    // appended to it in place below.
    const lastMessage = existingMessages.at(-1)!;
    const message = cloneMessageAndContent(lastMessage.message);
    if (
      message &&
      Array.isArray(message.content) &&
      message.content.length > 0
    ) {
      const lastIndex = message.content.length - 1;
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      (message.content as any[])[lastIndex] = {
        ...message.content[lastIndex],
      };
    }
    currentMessage = {
      ...lastMessage,
      message,
      status: statusFromStreamStatus(streamMessage.status),
    };
  } else {
    // Only a content-bearing part can seed a message; createStreamingMessage
    // throws on anything else.
    const firstIndex = parts.findIndex((p) => MESSAGE_PART_TYPES.has(p.type));
    if (firstIndex === -1) {
      // Nothing renderable yet, but the cursor did advance.
      return [
        { streamId: streamMessage.streamId, cursor, messages: [] },
        true,
      ];
    }
    currentMessage = createStreamingMessage(
      threadId,
      streamMessage,
      parts[firstIndex]!,
      existingMessages.length,
    );
    parts = parts.slice(firstIndex + 1);
  }

  const newStream = {
    streamId: streamMessage.streamId,
    cursor,
    messages: [...existingMessages.slice(0, -1), currentMessage],
  };
  let lastContent = getLastContent(currentMessage);
  for (const part of parts) {
    if (!MESSAGE_PART_TYPES.has(part.type)) {
      // Content-free parts must not reach the role check below, which would
      // hand them to createStreamingMessage (and throw).
      if (part.type !== "raw") {
        console.warn(`Received unexpected part: ${JSON.stringify(part)}`);
      }
      continue;
    }
    let contentToAdd:
      | Infer<typeof vTextPart>
      | Infer<typeof vToolCallPart>
      | Infer<typeof vToolResultPart>
      | Infer<typeof vReasoningPart>
      | undefined;
    const isToolRole = part.type === "source" || part.type === "tool-result";
    if (isToolRole !== (currentMessage.message!.role === "tool")) {
      // The role flips between assistant output and tool output, so this
      // part starts a new message.
      currentMessage = createStreamingMessage(
        threadId,
        streamMessage,
        part,
        newStream.messages.length,
      );
      lastContent = getLastContent(currentMessage);
      newStream.messages.push(currentMessage);
      continue;
    }
    switch (part.type) {
      case "text-delta": {
        const text = part.text;
        currentMessage.text = (currentMessage.text ?? "") + text;
        if (lastContent?.type === "text") {
          lastContent.text = (lastContent.text ?? "") + text;
        } else {
          contentToAdd = { type: "text", text } satisfies Infer<
            typeof vTextPart
          >;
        }
        break;
      }
      case "tool-input-start": {
        const toolCallId = part.id;
        currentMessage.tool = true;
        contentToAdd = {
          type: "tool-call",
          toolCallId,
          toolName: part.toolName,
          // String args mark a partial call still being streamed.
          args: "",
          providerExecuted:
            "providerExecuted" in part ? part.providerExecuted : undefined,
          providerOptions:
            "providerMetadata" in part ? part.providerMetadata : undefined,
        } satisfies Infer<typeof vToolCallPart>;
        break;
      }
      case "tool-input-delta": {
        currentMessage.tool = true;
        if (lastContent?.type !== "tool-call") {
          throw new Error("Expected last content to be a tool call");
        }
        if (typeof lastContent.args !== "string") {
          throw new Error("Expected args to be a string");
        }
        // Accept both the older `argsTextDelta` field and the current `delta`.
        const delta =
          "argsTextDelta" in part ? part.argsTextDelta : part.delta;
        lastContent.args = (lastContent.args ?? "") + delta;
        break;
      }
      case "tool-call": {
        currentMessage.tool = true;
        contentToAdd = toolCallContent(part);
        break;
      }
      case "tool-result": {
        contentToAdd = toolResultContent(part);
        break;
      }
      case "reasoning-delta": {
        currentMessage.reasoning = (currentMessage.reasoning ?? "") + part.text;
        if (lastContent?.type === "reasoning") {
          lastContent.text = (lastContent.text ?? "") + part.text;
        } else {
          contentToAdd = {
            type: "reasoning",
            text: part.text,
            providerOptions:
              "providerMetadata" in part ? part.providerMetadata : undefined,
            state: "streaming",
          } satisfies Infer<typeof vReasoningPart>;
        }
        break;
      }
      case "source":
        // Build a new array rather than pushing: `sources` may still be
        // shared with the corresponding message in `existing`.
        currentMessage.sources = [...(currentMessage.sources ?? []), part];
        break;
      default:
        console.warn(`Received unexpected part: ${JSON.stringify(part)}`);
        break;
    }
    if (contentToAdd) {
      if (!currentMessage.message!.content) {
        currentMessage.message!.content = [];
      }
      if (!Array.isArray(currentMessage.message?.content)) {
        throw new Error("Expected message content to be an array");
      }
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      currentMessage.message.content.push(contentToAdd as any);
      lastContent = contentToAdd;
    }
  }
  return [newStream, true];
}
/**
 * Converts a streamed "tool-call" part into tool-call message content.
 * The call's arguments come from `args` when the part has that field, and
 * from `input` otherwise.
 */
function toolCallContent(
  part: Extract<TextStreamPart<ToolSet>, { type: "tool-call" }>,
): Infer<typeof vToolCallPart> {
  const { toolCallId, toolName, providerExecuted } = part;
  const content: Infer<typeof vToolCallPart> = {
    type: "tool-call",
    toolCallId,
    toolName,
    args: "args" in part ? part.args : part.input,
    providerExecuted,
  };
  return content;
}
/**
 * Converts a streamed "tool-result" part into tool-result message content.
 * The part's `output` becomes `result`, and its `input` becomes `args`.
 */
function toolResultContent(
  part: Extract<TextStreamPart<ToolSet>, { type: "tool-result" }>,
): Infer<typeof vToolResultPart> {
  const { toolCallId, toolName, output, input, providerExecuted } = part;
  return {
    type: "tool-result",
    toolCallId,
    toolName,
    result: output,
    args: input,
    providerExecuted,
  } satisfies Infer<typeof vToolResultPart>;
}
/**
 * Shallow-copies a message so the copy can be mutated independently.
 * An array `content` is copied into a new array, but the content parts
 * themselves are shared. Non-array content is reused as-is, and `undefined`
 * passes through.
 */
function cloneMessageAndContent(
  message: Message | undefined,
): Message | undefined {
  if (!message) {
    return message;
  }
  const { content } = message;
  const copiedContent = Array.isArray(content) ? Array.from(content) : content;
  return { ...message, content: copiedContent } as typeof message;
}
/**
 * Returns the last content part of a MessageDoc's message. Returns
 * `undefined` when there is no message, its content isn't an array, or the
 * array is empty.
 */
function getLastContent(message: MessageDoc) {
  const content = message.message?.content;
  return Array.isArray(content) ? content.at(-1) : undefined;
}
/**
 * Maps a stream's status to the status of the messages it produces.
 * "finished" maps to "success" and "aborted" maps to "failed". "streaming",
 * and any status not listed here, maps to "pending".
 */
function statusFromStreamStatus(
  status: StreamMessage["status"],
): MessageStatus {
  if (status === "finished") {
    return "success";
  }
  if (status === "aborted") {
    return "failed";
  }
  return "pending";
}
// TODO: share more code with applyDeltasToStreamMessage
/**
 * Creates a new client-side MessageDoc for a streamed message, seeded with
 * its first part.
 *
 * Every field of `message` except `streamId` is copied onto the result. The
 * `_id` is synthesized as `${streamId}-${index}`, `_creationTime` is the
 * current time, and `status` is derived from the stream's status.
 *
 * @param threadId - Thread the message belongs to.
 * @param message - Stream metadata for the message.
 * @param part - First part of the message; it determines the role and the
 *   initial content.
 * @param index - Position of this message within its stream, used in `_id`.
 * @returns The new MessageDoc. "tool-result" and "source" parts produce a
 *   tool-role message; every other accepted part produces an assistant
 *   message.
 * @throws For part types that can't start a message (anything other than
 *   text, tool-call, tool-result, reasoning, or source parts).
 */
export function createStreamingMessage(
  threadId: string,
  message: StreamMessage,
  part: TextStreamPart<ToolSet>,
  index: number,
): MessageDoc {
  const { streamId, ...rest } = message;
  const metadata: MessageDoc = {
    ...rest,
    _id: `${streamId}-${index}`,
    _creationTime: Date.now(),
    status: statusFromStreamStatus(message.status),
    threadId,
    tool: false,
  };
  switch (part.type) {
    case "text-delta": {
      const text = part.text || "";
      return {
        ...metadata,
        message: { role: "assistant", content: [{ type: "text", text }] },
        text,
      };
    }
    case "tool-input-start": {
      return {
        ...metadata,
        tool: true,
        message: {
          role: "assistant",
          content: [
            {
              type: "tool-call",
              toolName: part.toolName,
              toolCallId: part.id,
              args: "", // when it's a string, it's a partial call
              providerExecuted:
                "providerExecuted" in part ? part.providerExecuted : undefined,
              providerOptions:
                "providerMetadata" in part ? part.providerMetadata : undefined,
            },
          ],
        },
      };
    }
    case "tool-input-delta": {
      console.warn("Received tool call delta part first??");
      const delta = part.delta;
      const toolCallId = part.id;
      // The part's `type` is not a tool name. Use `toolName` only if this part
      // carries one; presumably the name normally arrives with
      // "tool-input-start".
      const toolName =
        "toolName" in part && typeof part.toolName === "string"
          ? part.toolName
          : "";
      return {
        ...metadata,
        tool: true,
        message: {
          role: "assistant",
          content: [{ type: "tool-call", toolCallId, toolName, args: delta }],
        },
      };
    }
    case "tool-call": {
      return {
        ...metadata,
        tool: true,
        message: { role: "assistant", content: [toolCallContent(part)] },
      };
    }
    case "tool-result":
      return {
        ...metadata,
        tool: true,
        message: { role: "tool", content: [toolResultContent(part)] },
      };
    case "reasoning-delta": {
      // Mirror the reasoning content applyDeltasToStreamMessage appends, so
      // the first reasoning part looks like the ones added later.
      return {
        ...metadata,
        message: {
          role: "assistant",
          content: [
            {
              type: "reasoning",
              text: part.text,
              providerOptions:
                "providerMetadata" in part ? part.providerMetadata : undefined,
              state: "streaming",
            },
          ],
        },
        reasoning: part.text,
      };
    }
    case "source":
      console.warn("Received source part first??");
      return {
        ...metadata,
        tool: true,
        message: { role: "tool", content: [] },
        sources: [part],
      };
    // case "raw":
    //   return {
    //     ...metadata,
    //     message: { role: "assistant", content: [part.rawValue] },
    //   };
    default:
      throw new Error(`Unexpected part type: ${JSON.stringify(part)}`);
  }
}