@convex-dev/agent
Version:
An agent component for Convex.
412 lines • 15.7 kB
JavaScript
import { readUIMessageStream, } from "ai";
import { assert, pick } from "convex-helpers";
import {} from "./UIMessages.js";
import { joinText, sorted } from "./shared.js";
import {} from "./validators.js";
import { getErrorMessage } from "@ai-sdk/provider-utils";
/**
 * Creates an empty assistant UIMessage placeholder for a stream.
 *
 * @param streamMessage - Stream record; `streamId`, `order`, `stepOrder`,
 *   `status`, `agentName` and (optionally) `metadata` are read.
 * @param threadId - Thread id, used as the prefix of the message `key`.
 * @returns A message with no parts and empty text. Its id is
 *   `stream:<streamId>`, its status is mapped via `statusFromStreamStatus`,
 *   and `metadata` is copied only when the stream record has a truthy one.
 */
export function blankUIMessage(streamMessage, threadId) {
  const { streamId, order, stepOrder, status, agentName, metadata } = streamMessage;
  const placeholder = {
    id: `stream:${streamId}`,
    key: `${threadId}-${order}-${stepOrder}`,
    order,
    stepOrder,
    status: statusFromStreamStatus(status),
    agentName,
    text: "",
    _creationTime: Date.now(),
    role: "assistant",
    parts: [],
  };
  if (metadata) {
    placeholder.metadata = metadata;
  }
  return placeholder;
}
/**
 * Maps a stored stream status to the UIMessage status shown to clients.
 *
 * "streaming" stays "streaming", "finished" becomes "success", and
 * "aborted" becomes "failed". Anything else (including undefined) is
 * reported as "pending".
 */
export function statusFromStreamStatus(status) {
  if (status === "streaming") return "streaming";
  if (status === "finished") return "success";
  if (status === "aborted") return "failed";
  return "pending";
}
/**
 * Replays stored UIMessageChunk parts on top of `uiMessage` using the AI SDK's
 * `readUIMessageStream`, and returns the resulting message.
 *
 * @param uiMessage - Message to start from (e.g. from `blankUIMessage`).
 * @param parts - UIMessageChunk parts in stream order.
 * @returns The last message produced by the stream (or `uiMessage` if none
 *   was produced). Its `status` is set to "failed" if the stream reported an
 *   error, and its `text` is recomputed from its parts.
 * @throws If the stream produced a message with a different id, or if the
 *   iteration failed for a reason other than an error reported via `onError`.
 */
export async function updateFromUIMessageChunks(uiMessage, parts) {
  const partsStream = new ReadableStream({
    start(controller) {
      for (const part of parts) {
        controller.enqueue(part);
      }
      controller.close();
    },
  });
  let failed = false;
  const messageStream = readUIMessageStream({
    message: uiMessage,
    stream: partsStream,
    onError: (e) => {
      failed = true;
      console.error("Error in stream", e);
    },
    terminateOnError: true,
  });
  let message = uiMessage;
  try {
    for await (const messagePart of messageStream) {
      assert(messagePart.id === message.id, `Expecting to only make one UIMessage in a stream`);
      message = messagePart;
    }
  } catch (e) {
    // With `terminateOnError`, the output stream is errored right after
    // `onError` runs, which rejects this iteration. That error has already
    // been reported, so keep the partial message and mark it failed below
    // rather than rejecting the whole derivation. Anything else (such as
    // the assertion above) is unexpected and is rethrown.
    if (!failed) {
      throw e;
    }
  }
  if (failed) {
    message.status = "failed";
  }
  message.text = joinText(message.parts);
  return message;
}
/**
 * Builds UIMessages for the given stream messages from their stored deltas.
 *
 * Streams in the "UIMessageChunk" format are replayed with
 * `updateFromUIMessageChunks`. Any other format is handled by
 * `deriveUIMessagesFromTextStreamParts`.
 *
 * @param threadId - Thread the streams belong to.
 * @param streamMessages - Stream records to derive messages for.
 * @param allDeltas - Deltas for any of the streams; each is matched to its
 *   stream by `streamId`.
 * @returns The derived messages, ordered via `sorted`.
 */
export async function deriveUIMessagesFromDeltas(threadId, streamMessages, allDeltas) {
  // Bucket the deltas by stream once instead of re-filtering per stream.
  const deltasByStream = new Map();
  for (const delta of allDeltas) {
    const bucket = deltasByStream.get(delta.streamId);
    if (bucket) {
      bucket.push(delta);
    } else {
      deltasByStream.set(delta.streamId, [delta]);
    }
  }
  const derived = [];
  for (const streamMessage of streamMessages) {
    const streamDeltas = deltasByStream.get(streamMessage.streamId) ?? [];
    if (streamMessage.format !== "UIMessageChunk") {
      const [textStreamMessages] = deriveUIMessagesFromTextStreamParts(threadId, [streamMessage], [], streamDeltas);
      derived.push(...textStreamMessages);
      continue;
    }
    const { parts } = getParts(streamDeltas, 0);
    // TODO: this fails on partial tool calls
    derived.push(await updateFromUIMessageChunks(blankUIMessage(streamMessage, threadId), parts));
  }
  return sorted(derived);
}
/**
 * Incrementally derives UIMessages for streams stored as TextStreamParts.
 *
 * @param threadId - Thread the streams belong to.
 * @param streamMessages - The stream records to derive messages for.
 * @param existingStreams - Per-stream state `{ streamId, cursor, message }`
 *   from a previous call (empty on the first call).
 * @param allDeltas - Deltas for any of the streams; each stream only sees the
 *   deltas whose `streamId` matches.
 * @returns `[messages, newStreams, changed]`:
 *   - `messages`: the derived messages, ordered via `sorted`;
 *   - `newStreams`: the updated per-stream state, one per stream message;
 *   - `changed`: true if any stream's message changed, or if a stream in
 *     `existingStreams` is no longer among `streamMessages`.
 */
export function deriveUIMessagesFromTextStreamParts(threadId, streamMessages, existingStreams, allDeltas) {
  const newStreams = [];
  let changed = false;
  for (const streamMessage of streamMessages) {
    const { streamId } = streamMessage;
    const previous = existingStreams.find((s) => s.streamId === streamId);
    const streamDeltas = allDeltas.filter((d) => d.streamId === streamId);
    const [nextStream, streamChanged] = updateFromTextStreamParts(threadId, streamMessage, previous, streamDeltas);
    newStreams.push(nextStream);
    if (streamChanged) {
      changed = true;
    }
  }
  // A stream that was known before but is no longer active also counts as a change.
  const activeStreamIds = new Set(newStreams.map((s) => s.streamId));
  if (existingStreams.some(({ streamId }) => !activeStreamIds.has(streamId))) {
    changed = true;
  }
  return [sorted(newStreams.map((s) => s.message)), newStreams, changed];
}
/**
 * Collects the parts of the contiguous run of deltas that starts at
 * `fromCursor`.
 *
 * Deltas are visited in order of `start`. A sorted copy is used, so the
 * caller's array is left untouched.
 * - Deltas with no parts are logged and skipped (the cursor does not move).
 * - Deltas ending at or before the cursor were already consumed and are skipped.
 * - Processing stops, with a warning, at the first gap (a delta starting after the cursor).
 *
 * @param deltas - Deltas for one stream, in any order, each with
 *   `streamId`, `start`, `end` and `parts`.
 * @param fromCursor - Position to resume from; null/undefined means 0.
 * @returns `{ parts, cursor }`: the concatenated parts and the `end` of the
 *   last delta consumed (or the starting cursor if none were consumed).
 * @throws Error if a delta starts before the cursor but ends after it.
 */
export function getParts(deltas, fromCursor) {
  const parts = [];
  let cursor = fromCursor ?? 0;
  // Sort a copy: Array.prototype.sort is in place, and this function is
  // exported, so callers may still rely on their own array's order.
  const ordered = [...deltas].sort((a, b) => a.start - b.start);
  for (const delta of ordered) {
    if (delta.parts.length === 0) {
      console.debug(`Got delta with no parts: ${JSON.stringify(delta)}`);
      continue;
    }
    if (cursor !== delta.start) {
      if (cursor >= delta.end) {
        // Already consumed.
        continue;
      }
      else if (cursor < delta.start) {
        console.warn(`Got delta for stream ${delta.streamId} that has a gap ${cursor} -> ${delta.start}`);
        break;
      }
      else {
        // delta.start < cursor < delta.end: partially overlapping delta.
        throw new Error(`Got unexpected delta for stream ${delta.streamId}: delta: ${delta.start} -> ${delta.end} existing cursor: ${cursor}`);
      }
    }
    parts.push(...delta.parts);
    cursor = delta.end;
  }
  return { parts, cursor };
}
/**
 * Applies newly arrived TextStreamPart deltas for one stream on top of the
 * UIMessage derived from that stream so far.
 *
 * This is historically from when we would use the onChunk callback instead of
 * consuming the full UIMessageStream.
 *
 * @param threadId - Thread id; used to build the key of a fresh message.
 * @param streamMessage - The stream record (`streamId`, `status`, `order`,
 *   `stepOrder`, ...).
 * @param existing - Previously derived `{ streamId, cursor, message }` for
 *   this stream, or undefined the first time it is seen.
 * @param deltas - Deltas for this stream. Only the contiguous run starting at
 *   `existing.cursor` is applied (see `getParts`).
 * @returns `[state, changed]`. If no new parts arrived and the mapped status
 *   is unchanged, `state` is `existing` (or a new blank state) and `changed`
 *   is false. Otherwise `state.message` is a deep copy of the previous message
 *   with the new parts applied, and `changed` is true.
 * @throws If a tool input delta/end/result refers to a tool call that is not
 *   in the message.
 */
// exported for testing
export function updateFromTextStreamParts(threadId, streamMessage, existing, deltas) {
  const { cursor, parts } = getParts(deltas, existing?.cursor);
  const changed = parts.length > 0 ||
    (existing &&
      statusFromStreamStatus(streamMessage.status) !== existing.message.status);
  const existingMessage = existing?.message ?? blankUIMessage(streamMessage, threadId);
  if (!changed) {
    return [
      existing ?? {
        streamId: streamMessage.streamId,
        cursor,
        message: existingMessage,
      },
      false,
    ];
  }
  // Deep copy so the previous message object is not mutated.
  const message = structuredClone(existingMessage);
  message.status = statusFromStreamStatus(streamMessage.status);
  // Text and reasoning ids are only tracked within this call. A delta whose
  // start arrived in an earlier call attaches to the last part if that part
  // has the matching type. Tool parts are seeded from the existing message.
  const textPartsById = new Map();
  const toolPartsById = new Map(message.parts
    .filter((p) => p.type.startsWith("tool-") || p.type === "dynamic-tool")
    .map((p) => [p.toolCallId, p]));
  const reasoningPartsById = new Map();
  for (const part of parts) {
    switch (part.type) {
      case "text-start":
      case "text-delta": {
        if (!textPartsById.has(part.id)) {
          const lastPart = message.parts.at(-1);
          if (lastPart?.type === "text") {
            textPartsById.set(part.id, lastPart);
          }
          else {
            const newPart = {
              type: "text",
              text: "",
              providerMetadata: part.providerMetadata,
            };
            textPartsById.set(part.id, newPart);
            message.parts.push(newPart);
          }
        }
        if (part.type === "text-delta") {
          const textPart = textPartsById.get(part.id);
          textPart.text += part.text;
          textPart.providerMetadata = mergeProviderMetadata(textPart.providerMetadata, part.providerMetadata);
        }
        break;
      }
      case "tool-input-start": {
        let newPart;
        if (part.dynamic) {
          newPart = {
            type: "dynamic-tool",
            toolCallId: part.id,
            toolName: part.toolName,
            state: "input-streaming",
            input: "",
          };
        }
        else {
          newPart = {
            type: `tool-${part.toolName}`,
            toolCallId: part.id,
            state: "input-streaming",
            input: "",
            providerExecuted: part.providerExecuted,
          };
        }
        toolPartsById.set(part.id, newPart);
        message.parts.push(newPart);
        break;
      }
      case "tool-input-delta": {
        const toUpdate = toolPartsById.get(part.id);
        assert(toUpdate, `Expected to find tool call part ${part.id} to update`);
        // While streaming, the input is accumulated as raw text.
        toUpdate.input = (toUpdate.input ?? "") + part.delta;
        break;
      }
      case "tool-input-end": {
        const toUpdate = toolPartsById.get(part.id);
        assert(toUpdate, `Expected to find tool call part ${part.id} to update`);
        toUpdate.state = "input-available";
        if (part.providerMetadata) {
          const updatable = toUpdate;
          updatable.callProviderMetadata = mergeProviderMetadata(updatable.callProviderMetadata, part.providerMetadata);
        }
        break;
      }
      case "tool-call": {
        let newPart;
        if (part.dynamic) {
          newPart = {
            type: "dynamic-tool",
            toolCallId: part.toolCallId,
            toolName: part.toolName,
            input: part.input,
            state: "input-available",
          };
        }
        else {
          newPart = {
            type: `tool-${part.toolName}`,
            toolCallId: part.toolCallId,
            input: part.input,
            state: "input-available",
          };
          if (part.providerExecuted) {
            newPart.providerExecuted = part.providerExecuted;
          }
        }
        if (part.providerMetadata) {
          newPart.callProviderMetadata = part.providerMetadata;
        }
        if (toolPartsById.has(part.toolCallId)) {
          // The call was already streamed in via tool-input-*; finalize it in place.
          Object.assign(toolPartsById.get(part.toolCallId), newPart);
        }
        else {
          toolPartsById.set(part.toolCallId, newPart);
          message.parts.push(newPart);
        }
        break;
      }
      case "tool-result": {
        const toolCall = toolPartsById.get(part.toolCallId);
        assert(toolCall, `Expected to find tool call part ${part.toolCallId} to update with result`);
        // Same update for static and dynamic tools. `preliminary` is always
        // overwritten so a final result clears an earlier preliminary flag.
        Object.assign(toolCall, {
          state: "output-available",
          input: part.input ?? toolCall.input,
          output: part.output ?? toolCall.output,
          preliminary: part.preliminary,
        });
        break;
      }
      case "reasoning-start":
      case "reasoning-delta": {
        if (!reasoningPartsById.has(part.id)) {
          const lastPart = message.parts.at(-1);
          if (lastPart?.type === "reasoning") {
            reasoningPartsById.set(part.id, lastPart);
          }
          else {
            const newPart = {
              type: "reasoning",
              state: "streaming",
              text: "",
              providerMetadata: part.providerMetadata,
            };
            reasoningPartsById.set(part.id, newPart);
            message.parts.push(newPart);
          }
        }
        const reasoningPart = reasoningPartsById.get(part.id);
        if (part.type === "reasoning-delta") {
          reasoningPart.text += part.text;
          reasoningPart.providerMetadata = mergeProviderMetadata(reasoningPart.providerMetadata, part.providerMetadata);
        }
        break;
      }
      case "reasoning-end": {
        const reasoningPart = reasoningPartsById.get(part.id) ??
          message.parts.find((p) => p.type === "reasoning" && p.state === "streaming");
        if (reasoningPart) {
          reasoningPart.state = "done";
        }
        else {
          console.warn(`Expected to find reasoning part ${part.id} to finish, but found none`);
        }
        break;
      }
      case "source":
        if (part.sourceType === "url") {
          message.parts.push({
            type: "source-url",
            url: part.url,
            sourceId: part.id,
            providerMetadata: part.providerMetadata,
            title: part.title,
          });
        }
        else if (part.sourceType === "document") {
          message.parts.push({
            type: "source-document",
            mediaType: part.mediaType,
            sourceId: part.id,
            title: part.title,
            filename: part.filename,
            providerMetadata: part.providerMetadata,
          });
        }
        else {
          console.warn("Got source part with unknown source type", part);
        }
        break;
      case "abort":
        message.status = "failed";
        break;
      case "error":
        message.status = "failed";
        console.warn("Generation failed with error", part.error);
        break;
      case "tool-error": {
        const toolPart = toolPartsById.get(part.toolCallId);
        if (toolPart) {
          // A failed tool call moves to the "output-error" UI state, with
          // the error text attached, instead of staying "input-available".
          toolPart.state = "output-error";
          toolPart.input = part.input ?? toolPart.input;
          toolPart.errorText = getErrorMessage(part.error);
        }
        else {
          console.warn(`Expected to find tool call part ${part.toolCallId} to attach error to, but found none`);
        }
        break;
      }
      case "file":
      case "text-end":
      case "finish-step":
      case "finish":
      case "raw":
      case "start-step":
      case "start":
        // ignore
        break;
      default: {
        // Should never happen
        console.warn(`Received unexpected part: ${JSON.stringify(part)}`);
        break;
      }
    }
  }
  // Consider reasoning done once something else happens
  for (let i = 0; i < message.parts.length - 1; i++) {
    const part = message.parts[i];
    if (part.type === "reasoning") {
      part.state = "done";
    }
  }
  message.text = joinText(message.parts);
  return [
    {
      streamId: streamMessage.streamId,
      cursor,
      message,
    },
    true,
  ];
}
/**
 * Merges provider metadata from a stream part into existing metadata.
 *
 * Provider keys are combined, and each provider's entries are shallow-merged
 * with `part`'s values winning. Neither argument is mutated. This matters
 * because `existing` can be the very object stored on an incoming delta part
 * (it is returned by reference below and stored directly on new parts).
 *
 * @param existing - Metadata accumulated so far, or undefined.
 * @param part - Metadata from the newest part, or undefined.
 * @returns undefined if both are falsy; the other argument (same reference)
 *   if one is falsy; otherwise a new merged object.
 */
function mergeProviderMetadata(existing, part) {
  if (!existing && !part) {
    return undefined;
  }
  if (!existing) {
    return part;
  }
  if (!part) {
    return existing;
  }
  const merged = { ...existing };
  for (const [provider, metadata] of Object.entries(part)) {
    merged[provider] = {
      ...existing[provider],
      ...metadata,
    };
  }
  return merged;
}
//# sourceMappingURL=deltas.js.map