UNPKG

@convex-dev/agent

Version:

An agent component for Convex.

412 lines 15.7 kB
import { readUIMessageStream } from "ai";
import { assert, pick } from "convex-helpers";
import {} from "./UIMessages.js";
import { joinText, sorted } from "./shared.js";
import {} from "./validators.js";
import { getErrorMessage } from "@ai-sdk/provider-utils";

/**
 * Build an empty assistant UIMessage placeholder for a stream, before any
 * deltas have been applied to it.
 *
 * @param streamMessage - Stream record; reads `streamId`, `order`,
 *   `stepOrder`, `status`, `agentName` and optional `metadata`.
 * @param threadId - Thread the stream belongs to; used in the message `key`.
 * @returns A UIMessage with no parts, empty text, and a status derived from
 *   the stream's status. `metadata` is only present when the stream has it.
 */
export function blankUIMessage(streamMessage, threadId) {
  return {
    id: `stream:${streamMessage.streamId}`,
    key: `${threadId}-${streamMessage.order}-${streamMessage.stepOrder}`,
    order: streamMessage.order,
    stepOrder: streamMessage.stepOrder,
    status: statusFromStreamStatus(streamMessage.status),
    agentName: streamMessage.agentName,
    text: "",
    // NOTE(review): this is the local clock at derivation time, not the
    // stream's stored creation time.
    _creationTime: Date.now(),
    role: "assistant",
    parts: [],
    ...(streamMessage.metadata ? { metadata: streamMessage.metadata } : {}),
  };
}

/**
 * Map a stream status to a UIMessage status.
 *
 * "streaming" -> "streaming", "finished" -> "success", "aborted" -> "failed";
 * any other value (including undefined) -> "pending".
 */
export function statusFromStreamStatus(status) {
  switch (status) {
    case "streaming":
      return "streaming";
    case "finished":
      return "success";
    case "aborted":
      return "failed";
    default:
      return "pending";
  }
}

/**
 * Replay a list of UIMessageChunk parts onto `uiMessage` using the AI SDK's
 * `readUIMessageStream`.
 *
 * @param uiMessage - Starting message; the stream must only ever produce a
 *   message with this same id (asserted).
 * @param parts - UIMessageChunk parts, in order.
 * @returns The final message snapshot, with `status` set to "failed" if the
 *   stream reported an error, and `text` recomputed from its parts.
 */
export async function updateFromUIMessageChunks(uiMessage, parts) {
  const partsStream = new ReadableStream({
    start(controller) {
      for (const part of parts) {
        controller.enqueue(part);
      }
      controller.close();
    },
  });
  let failed = false;
  const messageStream = readUIMessageStream({
    message: uiMessage,
    stream: partsStream,
    onError: (e) => {
      failed = true;
      console.error("Error in stream", e);
    },
    terminateOnError: true,
  });
  let message = uiMessage;
  for await (const messagePart of messageStream) {
    assert(
      messagePart.id === message.id,
      `Expecting to only make one UIMessage in a stream`,
    );
    message = messagePart;
  }
  if (failed) {
    message.status = "failed";
  }
  message.text = joinText(message.parts);
  return message;
}

/**
 * Derive UIMessages for a set of streams from scratch (cursor 0), choosing
 * the decoding path based on each stream's `format`.
 *
 * @param threadId - Thread the streams belong to.
 * @param streamMessages - Stream records to materialize.
 * @param allDeltas - Deltas for all of the streams; filtered per stream.
 * @returns The derived messages, ordered by `sorted` from ./shared.js.
 */
export async function deriveUIMessagesFromDeltas(
  threadId,
  streamMessages,
  allDeltas,
) {
  const messages = [];
  for (const streamMessage of streamMessages) {
    if (streamMessage.format === "UIMessageChunk") {
      const { parts } = getParts(
        allDeltas.filter((d) => d.streamId === streamMessage.streamId),
        0,
      );
      const uiMessage = await updateFromUIMessageChunks(
        blankUIMessage(streamMessage, threadId),
        parts,
      );
      // TODO: this fails on partial tool calls
      messages.push(uiMessage);
    } else {
      const [uiMessages] = deriveUIMessagesFromTextStreamParts(
        threadId,
        [streamMessage],
        [],
        allDeltas,
      );
      messages.push(...uiMessages);
    }
  }
  return sorted(messages);
}

/**
 * Incrementally derive UIMessages for text-stream-part streams, resuming each
 * stream from the cursor recorded in `existingStreams`.
 *
 * @param threadId - Thread the streams belong to.
 * @param streamMessages - Streams that are currently active.
 * @param existingStreams - Previously derived `{ streamId, cursor, message }`
 *   entries; matched to `streamMessages` by `streamId`.
 * @param allDeltas - Deltas for all streams; filtered per stream.
 * @returns `[messages, newStreams, changed]`. `messages` is sorted.
 *   `changed` is true if any stream produced new parts or a status change, or
 *   if an entry in `existingStreams` has no corresponding active stream.
 */
export function deriveUIMessagesFromTextStreamParts(
  threadId,
  streamMessages,
  existingStreams,
  allDeltas,
) {
  const newStreams = [];
  let changed = false;
  for (const streamMessage of streamMessages) {
    const deltas = allDeltas.filter(
      (d) => d.streamId === streamMessage.streamId,
    );
    const existing = existingStreams.find(
      (s) => s.streamId === streamMessage.streamId,
    );
    const [newStream, messageChanged] = updateFromTextStreamParts(
      threadId,
      streamMessage,
      existing,
      deltas,
    );
    newStreams.push(newStream);
    if (messageChanged) changed = true;
  }
  // A stream present before but absent now is no longer active.
  const activeStreamIds = new Set(newStreams.map((s) => s.streamId));
  if (existingStreams.some(({ streamId }) => !activeStreamIds.has(streamId))) {
    changed = true;
  }
  const messages = sorted(newStreams.map((s) => s.message));
  return [messages, newStreams, changed];
}

/**
 * Collect the contiguous run of parts starting at `fromCursor`.
 *
 * Deltas are visited in `start` order. Deltas with no parts are skipped, and
 * deltas that end at or before the cursor are treated as already consumed.
 * A gap stops collection (with a warning). A delta that overlaps the cursor
 * partially throws.
 *
 * @param deltas - Deltas for a single stream (`start`, `end`, `parts`,
 *   `streamId`). The array is not modified.
 * @param fromCursor - Position to resume from; defaults to 0.
 * @returns `{ parts, cursor }`: the collected parts and the new cursor.
 * @throws Error when a delta starts before the cursor but ends after it.
 */
export function getParts(deltas, fromCursor) {
  const parts = [];
  let cursor = fromCursor ?? 0;
  // Sort a copy: this function is exported and must not reorder the
  // caller's array in place.
  const ordered = [...deltas].sort((a, b) => a.start - b.start);
  for (const delta of ordered) {
    if (delta.parts.length === 0) {
      console.debug(`Got delta with no parts: ${JSON.stringify(delta)}`);
      continue;
    }
    if (cursor !== delta.start) {
      if (cursor >= delta.end) {
        // Already consumed.
        continue;
      } else if (cursor < delta.start) {
        console.warn(
          `Got delta for stream ${delta.streamId} that has a gap ${cursor} -> ${delta.start}`,
        );
        break;
      } else {
        throw new Error(
          `Got unexpected delta for stream ${delta.streamId}: delta: ${delta.start} -> ${delta.end} existing cursor: ${cursor}`,
        );
      }
    }
    parts.push(...delta.parts);
    cursor = delta.end;
  }
  return { parts, cursor };
}

/**
 * This is historically from when we would use the onChunk callback instead of
 * consuming the full UIMessageStream.
 *
 * Apply any new text-stream parts (past `existing.cursor`) to a clone of the
 * existing message.
 *
 * @returns `[stream, changed]`, where `stream` is `{ streamId, cursor,
 *   message }`. When nothing changed, `existing` is returned as-is (or a
 *   fresh blank entry if there was none), with `changed` false.
 */
// exported for testing
export function updateFromTextStreamParts(
  threadId,
  streamMessage,
  existing,
  deltas,
) {
  const { cursor, parts } = getParts(deltas, existing?.cursor);
  const changed =
    parts.length > 0 ||
    (existing &&
      statusFromStreamStatus(streamMessage.status) !== existing.message.status);
  const existingMessage =
    existing?.message ?? blankUIMessage(streamMessage, threadId);
  if (!changed) {
    return [
      existing ?? {
        streamId: streamMessage.streamId,
        cursor,
        message: existingMessage,
      },
      false,
    ];
  }
  // Work on a deep copy so the previously returned message is not mutated.
  const message = structuredClone(existingMessage);
  message.status = statusFromStreamStatus(streamMessage.status);
  const textPartsById = new Map();
  // Seed with tool parts from earlier batches so later deltas can find them.
  const toolPartsById = new Map(
    message.parts
      .filter((p) => p.type.startsWith("tool-") || p.type === "dynamic-tool")
      .map((p) => [p.toolCallId, p]),
  );
  const reasoningPartsById = new Map();
  for (const part of parts) {
    switch (part.type) {
      case "text-start":
      case "text-delta": {
        if (!textPartsById.has(part.id)) {
          // Continue a trailing text part (e.g. one started in a prior batch)
          // rather than opening a new one.
          const lastPart = message.parts.at(-1);
          if (lastPart?.type === "text") {
            textPartsById.set(part.id, lastPart);
          } else {
            const newPart = {
              type: "text",
              text: "",
              providerMetadata: part.providerMetadata,
            };
            textPartsById.set(part.id, newPart);
            message.parts.push(newPart);
          }
        }
        if (part.type === "text-delta") {
          const textPart = textPartsById.get(part.id);
          textPart.text += part.text;
          textPart.providerMetadata = mergeProviderMetadata(
            textPart.providerMetadata,
            part.providerMetadata,
          );
        }
        break;
      }
      case "tool-input-start": {
        let newPart;
        if (part.dynamic) {
          newPart = {
            type: "dynamic-tool",
            toolCallId: part.id,
            toolName: part.toolName,
            state: "input-streaming",
            input: "",
          };
        } else {
          newPart = {
            type: `tool-${part.toolName}`,
            toolCallId: part.id,
            state: "input-streaming",
            input: "",
            providerExecuted: part.providerExecuted,
          };
        }
        toolPartsById.set(part.id, newPart);
        message.parts.push(newPart);
        break;
      }
      case "tool-input-delta": {
        const toUpdate = toolPartsById.get(part.id);
        assert(
          toUpdate,
          `Expected to find tool call part ${part.id} to update`,
        );
        // While streaming, `input` accumulates the raw input text.
        toUpdate.input = (toUpdate.input ?? "") + part.delta;
        break;
      }
      case "tool-input-end": {
        const toUpdate = toolPartsById.get(part.id);
        assert(
          toUpdate,
          `Expected to find tool call part ${part.id} to update`,
        );
        toUpdate.state = "input-available";
        if (part.providerMetadata) {
          toUpdate.callProviderMetadata = mergeProviderMetadata(
            toUpdate.callProviderMetadata,
            part.providerMetadata,
          );
        }
        break;
      }
      case "tool-call": {
        let newPart;
        if (part.dynamic) {
          newPart = {
            type: "dynamic-tool",
            toolCallId: part.toolCallId,
            toolName: part.toolName,
            input: part.input,
            state: "input-available",
          };
        } else {
          newPart = {
            type: `tool-${part.toolName}`,
            toolCallId: part.toolCallId,
            input: part.input,
            state: "input-available",
          };
          if (part.providerExecuted) {
            newPart.providerExecuted = part.providerExecuted;
          }
        }
        if (part.providerMetadata) {
          newPart.callProviderMetadata = part.providerMetadata;
        }
        if (toolPartsById.has(part.toolCallId)) {
          // Overwrite the streamed-input part in place with the final call.
          const toUpdate = toolPartsById.get(part.toolCallId);
          Object.assign(toUpdate, newPart);
        } else {
          toolPartsById.set(part.toolCallId, newPart);
          message.parts.push(newPart);
        }
        break;
      }
      case "tool-result": {
        const toolCall = toolPartsById.get(part.toolCallId);
        assert(
          toolCall,
          `Expected to find tool call part ${part.toolCallId} to update with result`,
        );
        let newPart;
        if (toolCall.type === "dynamic-tool") {
          newPart = {
            ...toolCall,
            state: "output-available",
            input: part.input ?? toolCall.input,
            output: part.output ?? toolCall.output,
            ...pick(part, ["preliminary"]),
          };
        } else {
          newPart = {
            ...toolCall,
            state: "output-available",
            input: part.input ?? toolCall.input,
            output: part.output ?? toolCall.output,
            preliminary: part.preliminary,
          };
        }
        Object.assign(toolCall, newPart);
        break;
      }
      case "reasoning-start":
      case "reasoning-delta": {
        if (!reasoningPartsById.has(part.id)) {
          const lastPart = message.parts.at(-1);
          if (lastPart?.type === "reasoning") {
            reasoningPartsById.set(part.id, lastPart);
          } else {
            const newPart = {
              type: "reasoning",
              state: "streaming",
              text: "",
              providerMetadata: part.providerMetadata,
            };
            reasoningPartsById.set(part.id, newPart);
            message.parts.push(newPart);
          }
        }
        const reasoningPart = reasoningPartsById.get(part.id);
        if (part.type === "reasoning-delta") {
          reasoningPart.text += part.text;
          reasoningPart.providerMetadata = mergeProviderMetadata(
            reasoningPart.providerMetadata,
            part.providerMetadata,
          );
        }
        break;
      }
      case "reasoning-end": {
        // Fall back to the first still-streaming reasoning part when the id
        // was started in an earlier batch.
        const reasoningPart =
          reasoningPartsById.get(part.id) ??
          message.parts.find(
            (p) => p.type === "reasoning" && p.state === "streaming",
          );
        if (reasoningPart) {
          reasoningPart.state = "done";
        } else {
          console.warn(
            `Expected to find reasoning part ${part.id} to finish, but found none`,
          );
        }
        break;
      }
      case "source":
        if (part.sourceType === "url") {
          message.parts.push({
            type: "source-url",
            url: part.url,
            sourceId: part.id,
            providerMetadata: part.providerMetadata,
            title: part.title,
          });
        } else if (part.sourceType === "document") {
          message.parts.push({
            type: "source-document",
            mediaType: part.mediaType,
            sourceId: part.id,
            title: part.title,
            filename: part.filename,
            providerMetadata: part.providerMetadata,
          });
        } else {
          console.warn("Got source part with unknown source type", part);
        }
        break;
      case "abort":
        message.status = "failed";
        break;
      case "error":
        message.status = "failed";
        console.warn("Generation failed with error", part.error);
        break;
      case "tool-error": {
        const toolPart = toolPartsById.get(part.toolCallId);
        if (toolPart) {
          // A failed tool call is represented by state "output-error" plus
          // errorText; setting errorText alone leaves the part looking like
          // it is still awaiting output.
          toolPart.state = "output-error";
          toolPart.errorText = getErrorMessage(part.error);
        } else {
          console.warn(
            `Expected to find tool call part ${part.toolCallId} to mark as errored, but found none`,
          );
        }
        break;
      }
      case "file":
      case "text-end":
      case "finish-step":
      case "finish":
      case "raw":
      case "start-step":
      case "start":
        // ignore
        break;
      default: {
        // Should never happen
        const _ = part;
        console.warn(`Received unexpected part: ${JSON.stringify(part)}`);
        break;
      }
    }
  }
  // Consider reasoning done once something else happens
  for (let i = 0; i < message.parts.length - 1; i++) {
    const part = message.parts[i];
    if (part.type === "reasoning") {
      part.state = "done";
    }
  }
  message.text = joinText(message.parts);
  return [
    {
      streamId: streamMessage.streamId,
      cursor,
      message,
    },
    true,
  ];
}

/**
 * Merge provider metadata, with keys from `part` winning per provider.
 *
 * Never mutates either argument. New parts alias provider metadata objects
 * that come straight from the caller's deltas, so an in-place merge would
 * corrupt those deltas.
 *
 * @returns undefined if both are absent; the present one if only one is
 *   present; otherwise a new object with per-provider shallow merges.
 */
function mergeProviderMetadata(existing, part) {
  if (!existing && !part) {
    return undefined;
  }
  if (!existing) {
    return part;
  }
  if (!part) {
    return existing;
  }
  const merged = { ...existing };
  for (const [provider, metadata] of Object.entries(part)) {
    merged[provider] = {
      ...merged[provider],
      ...metadata,
    };
  }
  return merged;
}
//# sourceMappingURL=deltas.js.map