UNPKG

@convex-dev/agent

Version:

An agent component for Convex.

532 lines (518 loc) 15.3 kB
import {
  readUIMessageStream,
  type DynamicToolUIPart,
  type ProviderMetadata,
  type ReasoningUIPart,
  type TextStreamPart,
  type TextUIPart,
  type ToolSet,
  type ToolUIPart,
  type UIMessageChunk,
} from "ai";
import { assert, pick } from "convex-helpers";
import { type UIMessage } from "./UIMessages.js";
import { joinText, sorted } from "./shared.js";
import {
  type MessageStatus,
  type StreamDelta,
  type StreamMessage,
} from "./validators.js";
import { getErrorMessage } from "@ai-sdk/provider-utils";

/** Per-stream bookkeeping: how far we have consumed and the message so far. */
type StreamState = { streamId: string; cursor: number; message: UIMessage };

/**
 * Builds an empty assistant UIMessage for a stream.
 *
 * The id is `stream:<streamId>`, the key is `<threadId>-<order>-<stepOrder>`,
 * the status is mapped via {@link statusFromStreamStatus}, and
 * `_creationTime` is set to `Date.now()` at call time.
 *
 * @param streamMessage The stream to create a message for. Its `metadata`,
 *   when truthy, is copied onto the message.
 * @param threadId The thread id, used only to build the message key.
 * @returns A UIMessage with no parts and empty text.
 */
export function blankUIMessage<METADATA = unknown>(
  streamMessage: StreamMessage & { metadata?: METADATA },
  threadId: string,
): UIMessage<METADATA> {
  return {
    id: `stream:${streamMessage.streamId}`,
    key: `${threadId}-${streamMessage.order}-${streamMessage.stepOrder}`,
    order: streamMessage.order,
    stepOrder: streamMessage.stepOrder,
    status: statusFromStreamStatus(streamMessage.status),
    agentName: streamMessage.agentName,
    text: "",
    _creationTime: Date.now(),
    role: "assistant",
    parts: [],
    ...(streamMessage.metadata ? { metadata: streamMessage.metadata } : {}),
  };
}

/**
 * Maps a stream status onto a message status.
 *
 * `"streaming"` stays `"streaming"`, `"finished"` becomes `"success"`,
 * `"aborted"` becomes `"failed"`, and anything else becomes `"pending"`.
 */
export function statusFromStreamStatus(
  status: StreamMessage["status"],
): MessageStatus | "streaming" {
  switch (status) {
    case "streaming":
      return "streaming";
    case "finished":
      return "success";
    case "aborted":
      return "failed";
    default:
      return "pending";
  }
}

/**
 * Applies a list of UIMessageChunks to a UIMessage by replaying them through
 * `readUIMessageStream`.
 *
 * @param uiMessage The starting message passed to `readUIMessageStream`.
 * @param parts The chunks to replay, in order.
 * @returns The last message yielded by the stream (or `uiMessage` if nothing
 *   was yielded), with `text` recomputed from its parts and `status` set to
 *   `"failed"` if the stream reported an error.
 * @throws If the stream yields a message whose id differs from the current
 *   message's id.
 */
export async function updateFromUIMessageChunks(
  uiMessage: UIMessage,
  parts: UIMessageChunk[],
) {
  // Wrap the already-available chunks in a stream that closes immediately.
  const partsStream = new ReadableStream<UIMessageChunk>({
    start(controller) {
      for (const part of parts) {
        controller.enqueue(part);
      }
      controller.close();
    },
  });
  let failed = false;
  const messageStream = readUIMessageStream({
    message: uiMessage,
    stream: partsStream,
    onError: (e) => {
      failed = true;
      console.error("Error in stream", e);
    },
    terminateOnError: true,
  });
  let message = uiMessage;
  for await (const messagePart of messageStream) {
    assert(
      messagePart.id === message.id,
      `Expecting to only make one UIMessage in a stream`,
    );
    message = messagePart;
  }
  if (failed) {
    message.status = "failed";
  }
  message.text = joinText(message.parts);
  return message;
}

/**
 * Derives one UIMessage per stream from the stream metadata and its deltas.
 *
 * Streams whose `format` is `"UIMessageChunk"` are rebuilt with
 * {@link updateFromUIMessageChunks}; all others go through
 * {@link deriveUIMessagesFromTextStreamParts}.
 *
 * @param threadId The thread id, used to build message keys.
 * @param streamMessages The streams to derive messages for.
 * @param allDeltas Deltas for any of the streams; they are filtered by
 *   `streamId` per stream.
 * @returns The derived messages, passed through `sorted`.
 */
export async function deriveUIMessagesFromDeltas(
  threadId: string,
  streamMessages: StreamMessage[],
  allDeltas: StreamDelta[],
): Promise<UIMessage[]> {
  const messages: UIMessage[] = [];
  for (const streamMessage of streamMessages) {
    if (streamMessage.format === "UIMessageChunk") {
      const { parts } = getParts<UIMessageChunk>(
        allDeltas.filter((d) => d.streamId === streamMessage.streamId),
        0,
      );
      const uiMessage = await updateFromUIMessageChunks(
        blankUIMessage(streamMessage, threadId),
        parts,
      );
      // TODO: this fails on partial tool calls
      messages.push(uiMessage);
    } else {
      const [uiMessages] = deriveUIMessagesFromTextStreamParts(
        threadId,
        [streamMessage],
        [],
        allDeltas,
      );
      messages.push(...uiMessages);
    }
  }
  return sorted(messages);
}

/**
 * Incrementally derives UIMessages from TextStreamPart deltas.
 *
 * For each stream message, the matching deltas and any previously computed
 * state are passed to {@link updateFromTextStreamParts}.
 *
 * @param threadId The thread id, used to build message keys.
 * @param streamMessages The currently active streams.
 * @param existingStreams State returned by a previous call, matched to
 *   streams by `streamId`.
 * @param allDeltas Deltas for any of the streams; filtered by `streamId`.
 * @returns A tuple of: the messages passed through `sorted`, the new
 *   per-stream state to pass to the next call, and whether anything changed
 *   (a message was updated, or a previously tracked stream is no longer in
 *   `streamMessages`).
 */
export function deriveUIMessagesFromTextStreamParts(
  threadId: string,
  streamMessages: StreamMessage[],
  existingStreams: Array<{
    streamId: string;
    cursor: number;
    message: UIMessage;
  }>,
  allDeltas: StreamDelta[],
): [
  UIMessage[],
  Array<{ streamId: string; cursor: number; message: UIMessage }>,
  boolean,
] {
  const newStreams: StreamState[] = [];
  // Seed the existing chunks
  let changed = false;
  for (const streamMessage of streamMessages) {
    const deltas = allDeltas.filter(
      (d) => d.streamId === streamMessage.streamId,
    );
    const existing = existingStreams.find(
      (s) => s.streamId === streamMessage.streamId,
    );
    const [newStream, messageChanged] = updateFromTextStreamParts(
      threadId,
      streamMessage,
      existing,
      deltas,
    );
    newStreams.push(newStream);
    if (messageChanged) changed = true;
  }
  // A previously tracked stream that is no longer active counts as a change.
  const activeStreamIds = new Set(newStreams.map((s) => s.streamId));
  if (existingStreams.some((s) => !activeStreamIds.has(s.streamId))) {
    changed = true;
  }
  const messages = sorted(newStreams.map((s) => s.message));
  return [messages, newStreams, changed];
}

/**
 * Collects the parts of contiguous deltas, starting from a cursor.
 *
 * Deltas are processed in ascending `start` order (sorted on a copy, so the
 * caller's array is not reordered). For each delta:
 * - a delta with no parts is skipped (with a debug log);
 * - a delta that ends at or before the cursor is skipped;
 * - a delta that starts after the cursor is a gap: a warning is logged and
 *   processing stops;
 * - a delta that starts exactly at the cursor is appended and the cursor
 *   advances to its `end`.
 *
 * @param deltas The deltas to read; expected to belong to a single stream.
 * @param fromCursor Where to start reading. Defaults to 0.
 * @returns The collected parts and the cursor after the last consumed delta.
 * @throws If a delta partially overlaps the cursor
 *   (`start < cursor < end`).
 */
export function getParts<T extends StreamDelta["parts"][number]>(
  deltas: StreamDelta[],
  fromCursor?: number,
): { parts: T[]; cursor: number } {
  const parts: T[] = [];
  let cursor = fromCursor ?? 0;
  // Copy before sorting: Array.prototype.sort mutates in place.
  const ordered = [...deltas].sort((a, b) => a.start - b.start);
  for (const delta of ordered) {
    if (delta.parts.length === 0) {
      console.debug(`Got delta with no parts: ${JSON.stringify(delta)}`);
      continue;
    }
    if (cursor !== delta.start) {
      if (cursor >= delta.end) {
        // Already consumed.
        continue;
      } else if (cursor < delta.start) {
        console.warn(
          `Got delta for stream ${delta.streamId} that has a gap ${cursor} -> ${delta.start}`,
        );
        break;
      } else {
        throw new Error(
          `Got unexpected delta for stream ${delta.streamId}: delta: ${delta.start} -> ${delta.end} existing cursor: ${cursor}`,
        );
      }
    }
    parts.push(...delta.parts);
    cursor = delta.end;
  }
  return { parts, cursor };
}

/**
 * Applies new TextStreamPart deltas for one stream on top of its previously
 * derived message.
 *
 * This path dates from when the `onChunk` callback was used instead of
 * consuming the full UIMessageStream.
 *
 * If there are no new parts and the status is unchanged, the existing state
 * (or a fresh blank one) is returned with `false`. Otherwise the existing
 * message is deep-cloned with `structuredClone` and updated, so the message
 * in `existing` is not modified.
 *
 * Exported for testing.
 *
 * @param threadId The thread id, used to build the key of a blank message.
 * @param streamMessage The stream being updated.
 * @param existing State from a previous call for this stream, if any.
 * @param deltas The deltas for this stream.
 * @returns A tuple of the new stream state and whether the message changed.
 * @throws If a tool input delta/end or tool result references a tool call id
 *   for which no tool part exists.
 */
export function updateFromTextStreamParts(
  threadId: string,
  streamMessage: StreamMessage,
  existing:
    | { streamId: string; cursor: number; message: UIMessage }
    | undefined,
  deltas: StreamDelta[],
): [{ streamId: string; cursor: number; message: UIMessage }, boolean] {
  const { cursor, parts } = getParts<TextStreamPart<ToolSet>>(
    deltas,
    existing?.cursor,
  );
  const changed =
    parts.length > 0 ||
    (existing &&
      statusFromStreamStatus(streamMessage.status) !==
        existing.message.status);
  const existingMessage =
    existing?.message ?? blankUIMessage(streamMessage, threadId);
  if (!changed) {
    return [
      existing ?? {
        streamId: streamMessage.streamId,
        cursor,
        message: existingMessage,
      },
      false,
    ];
  }
  const message: UIMessage = structuredClone(existingMessage);
  message.status = statusFromStreamStatus(streamMessage.status);

  // Lookup tables from stream part id / tool call id to the UI part being
  // built. Only tool parts are seeded from the existing message; text and
  // reasoning parts are re-attached via the "last part" check below.
  const textPartsById = new Map<string, TextUIPart>();
  const toolPartsById = new Map<string, ToolUIPart | DynamicToolUIPart>(
    message.parts
      .filter(
        (p): p is ToolUIPart | DynamicToolUIPart =>
          p.type.startsWith("tool-") || p.type === "dynamic-tool",
      )
      .map((p) => [p.toolCallId, p]),
  );
  const reasoningPartsById = new Map<string, ReasoningUIPart>();

  for (const part of parts) {
    switch (part.type) {
      case "text-start":
      case "text-delta": {
        if (!textPartsById.has(part.id)) {
          // Continue the trailing text part if there is one, else start one.
          const lastPart = message.parts.at(-1);
          if (lastPart?.type === "text") {
            textPartsById.set(part.id, lastPart);
          } else {
            const newPart = {
              type: "text",
              text: "",
              providerMetadata: part.providerMetadata,
            } satisfies TextUIPart;
            textPartsById.set(part.id, newPart);
            message.parts.push(newPart);
          }
        }
        if (part.type === "text-delta") {
          const textPart = textPartsById.get(part.id)!;
          textPart.text += part.text;
          textPart.providerMetadata = mergeProviderMetadata(
            textPart.providerMetadata,
            part.providerMetadata,
          );
        }
        break;
      }
      case "tool-input-start": {
        let newPart: ToolUIPart | DynamicToolUIPart;
        if (part.dynamic) {
          newPart = {
            type: "dynamic-tool",
            toolCallId: part.id,
            toolName: part.toolName,
            state: "input-streaming",
            input: "",
          } satisfies DynamicToolUIPart;
        } else {
          newPart = {
            type: `tool-${part.toolName}`,
            toolCallId: part.id,
            state: "input-streaming",
            input: "",
            providerExecuted: part.providerExecuted,
          } satisfies ToolUIPart;
        }
        toolPartsById.set(part.id, newPart);
        message.parts.push(newPart);
        break;
      }
      case "tool-input-delta": {
        const toUpdate = toolPartsById.get(part.id);
        assert(
          toUpdate,
          `Expected to find tool call part ${part.id} to update`,
        );
        toUpdate.input = (toUpdate.input ?? "") + part.delta;
        break;
      }
      case "tool-input-end": {
        const toUpdate = toolPartsById.get(part.id);
        assert(
          toUpdate,
          `Expected to find tool call part ${part.id} to update`,
        );
        toUpdate.state = "input-available";
        if (part.providerMetadata) {
          const updatable = toUpdate as Extract<
            ToolUIPart | DynamicToolUIPart,
            { state: "input-available" }
          >;
          updatable.callProviderMetadata = mergeProviderMetadata(
            updatable.callProviderMetadata,
            part.providerMetadata,
          );
        }
        break;
      }
      case "tool-call": {
        let newPart: ToolUIPart | DynamicToolUIPart;
        if (part.dynamic) {
          newPart = {
            type: "dynamic-tool",
            toolCallId: part.toolCallId,
            toolName: part.toolName,
            input: part.input,
            state: "input-available",
          };
        } else {
          newPart = {
            type: `tool-${part.toolName}`,
            toolCallId: part.toolCallId,
            input: part.input,
            state: "input-available",
          };
          if (part.providerExecuted) {
            newPart.providerExecuted = part.providerExecuted;
          }
        }
        if (part.providerMetadata) {
          newPart.callProviderMetadata = part.providerMetadata;
        }
        // Overwrite a part created by streamed tool input, else append.
        if (toolPartsById.has(part.toolCallId)) {
          const toUpdate = toolPartsById.get(part.toolCallId)!;
          Object.assign(toUpdate, newPart);
        } else {
          toolPartsById.set(part.toolCallId, newPart);
          message.parts.push(newPart);
        }
        break;
      }
      case "tool-result": {
        const toolCall = toolPartsById.get(part.toolCallId);
        assert(
          toolCall,
          `Expected to find tool call part ${part.toolCallId} to update with result`,
        );
        let newPart: ToolUIPart | DynamicToolUIPart;
        if (toolCall.type === "dynamic-tool") {
          newPart = {
            ...toolCall,
            state: "output-available",
            input: part.input ?? toolCall.input,
            output: part.output ?? toolCall.output,
            ...pick(part, ["preliminary"]),
          } as DynamicToolUIPart;
        } else {
          newPart = {
            ...toolCall,
            state: "output-available",
            input: part.input ?? toolCall.input,
            output: part.output ?? toolCall.output,
            preliminary: part.preliminary,
          } as ToolUIPart;
        }
        // Update in place so the object already in message.parts changes.
        Object.assign(toolCall, newPart);
        break;
      }
      case "reasoning-start":
      case "reasoning-delta": {
        if (!reasoningPartsById.has(part.id)) {
          // Continue the trailing reasoning part if there is one.
          const lastPart = message.parts.at(-1);
          if (lastPart?.type === "reasoning") {
            reasoningPartsById.set(part.id, lastPart);
          } else {
            const newPart = {
              type: "reasoning",
              state: "streaming",
              text: "",
              providerMetadata: part.providerMetadata,
            } satisfies ReasoningUIPart;
            reasoningPartsById.set(part.id, newPart);
            message.parts.push(newPart);
          }
        }
        const reasoningPart = reasoningPartsById.get(part.id)!;
        if (part.type === "reasoning-delta") {
          reasoningPart.text += part.text;
          reasoningPart.providerMetadata = mergeProviderMetadata(
            reasoningPart.providerMetadata,
            part.providerMetadata,
          );
        }
        break;
      }
      case "reasoning-end": {
        // Fall back to the first still-streaming reasoning part when the id
        // was not seen in this batch of parts.
        const reasoningPart =
          reasoningPartsById.get(part.id) ??
          message.parts.find(
            (p): p is ReasoningUIPart =>
              p.type === "reasoning" && p.state === "streaming",
          );
        if (reasoningPart) {
          reasoningPart.state = "done";
        } else {
          console.warn(
            `Expected to find reasoning part ${part.id} to finish, but found none`,
          );
        }
        break;
      }
      case "source":
        if (part.sourceType === "url") {
          message.parts.push({
            type: "source-url",
            url: part.url,
            sourceId: part.id,
            providerMetadata: part.providerMetadata,
            title: part.title,
          });
        } else if (part.sourceType === "document") {
          message.parts.push({
            type: "source-document",
            mediaType: part.mediaType,
            sourceId: part.id,
            title: part.title,
            filename: part.filename,
            providerMetadata: part.providerMetadata,
          });
        } else {
          console.warn("Got source part with unknown source type", part);
        }
        break;
      case "abort":
        message.status = "failed";
        break;
      case "error":
        message.status = "failed";
        console.warn("Generation failed with error", part.error);
        break;
      case "tool-error": {
        const toolPart = toolPartsById.get(part.toolCallId);
        if (toolPart) {
          toolPart.errorText = getErrorMessage(part.error);
        }
        break;
      }
      case "file":
      case "text-end":
      case "finish-step":
      case "finish":
      case "raw":
      case "start-step":
      case "start":
        // ignore
        break;
      default: {
        // Should never happen
        const _: never = part;
        console.warn(`Received unexpected part: ${JSON.stringify(part)}`);
        break;
      }
    }
  }
  // Consider reasoning done once something else happens
  for (let i = 0; i < message.parts.length - 1; i++) {
    const part = message.parts[i];
    if (part.type === "reasoning") {
      part.state = "done";
    }
  }
  message.text = joinText(message.parts);
  return [
    {
      streamId: streamMessage.streamId,
      cursor,
      message,
    },
    true,
  ];
}

/**
 * Merges two provider metadata objects, provider by provider.
 *
 * For each provider in `part`, its fields are spread over the fields already
 * present for that provider in `existing`. Neither argument is modified: when
 * both are given, a new top-level object is returned. When only one is
 * given, that same object is returned as-is.
 *
 * @returns The merged metadata, or `undefined` if both inputs are missing.
 */
function mergeProviderMetadata(
  existing: ProviderMetadata | undefined,
  part: ProviderMetadata | undefined,
): ProviderMetadata | undefined {
  if (!existing && !part) {
    return undefined;
  }
  if (!existing) {
    return part;
  }
  if (!part) {
    return existing;
  }
  // Shallow-copy so we never write into metadata owned by an input delta.
  const merged: ProviderMetadata = { ...existing };
  for (const [provider, metadata] of Object.entries(part)) {
    merged[provider] = {
      ...merged[provider],
      ...metadata,
    };
  }
  return merged;
}