UNPKG

@tanstack/ai

Version:

Core TanStack AI library — open-source AI SDK

879 lines (878 loc) 27.5 kB
import { devtoolsMiddleware } from "@tanstack/ai-event-client";
import { streamToText } from "../../stream-to-response.js";
import { LazyToolManager } from "./tools/lazy-tool-manager.js";
import {
  ToolCallManager,
  MiddlewareAbortError,
  executeToolCalls
} from "./tools/tool-calls.js";
import {
  convertSchemaToJsonSchema,
  isStandardSchema,
  parseWithStandardSchema
} from "./tools/schema-converter.js";
import { maxIterations } from "./agent-loop-strategies.js";
import { convertMessagesToModelMessages } from "./messages.js";
import { MiddlewareRunner } from "./middleware/compose.js";

/** Kind tag exported alongside `chat`. */
const kind = "text";

/**
 * Identity helper for declaring chat options.
 * @param {object} options - Chat options.
 * @returns {object} The same `options` object, unchanged.
 */
function createChatOptions(options) {
  return options;
}

/**
 * Drives a single chat request.
 *
 * It streams model output through the adapter, runs middleware hooks,
 * executes tool calls, and loops until the agent loop strategy or the tool
 * phase says to stop.
 *
 * Each loop iteration is a two-step cycle tracked by `cyclePhase`:
 *   - "processText":      stream one model response (streamModelResponse)
 *   - "executeToolCalls": run the tool calls from that response (processToolCalls)
 * `endCycle()` flips between the two and bumps `iterationCount` after the tool step.
 */
class TextEngine {
  /**
   * @param {object} config
   * @param {object} config.adapter - Model adapter; `chatStream`, `name` and
   *   (in callers) `model` / `structuredOutput` are used.
   * @param {object} config.params - Chat parameters (messages, tools, model,
   *   sampling options, abortController, agentLoopStrategy, ...).
   * @param {Array} [config.middleware] - User middleware, run after the
   *   built-in devtools middleware.
   * @param {*} [config.context] - Opaque value exposed to middleware as `ctx.context`.
   */
  constructor(config) {
    this.iterationCount = 0;
    this.lastFinishReason = null;
    this.streamStartTime = 0;
    this.totalChunkCount = 0;
    this.currentMessageId = null;
    this.accumulatedContent = "";
    this.finishedEvent = null;
    this.earlyTermination = false;
    this.toolPhase = "continue";
    this.cyclePhase = "processText";
    this.deferredPromises = [];
    this.terminalHookCalled = false;
    this.adapter = config.adapter;
    this.params = config.params;
    this.systemPrompts = config.params.systemPrompts || [];
    this.loopStrategy = config.params.agentLoopStrategy || maxIterations(5);
    this.initialMessageCount = config.params.messages.length;

    // Approval state lives in UI-message `parts`, which the model-message
    // conversion below drops, so it must be read first.
    const { approvals, clientToolResults } =
      this.extractClientStateFromOriginalMessages(config.params.messages);
    this.initialApprovals = approvals;
    this.initialClientToolResults = clientToolResults;
    this.messages = convertMessagesToModelMessages(config.params.messages);

    this.lazyToolManager = new LazyToolManager(
      config.params.tools || [],
      this.messages
    );
    this.tools = this.lazyToolManager.getActiveTools();
    this.toolCallManager = new ToolCallManager(this.tools);
    this.requestId = this.createId("chat");
    this.streamId = this.createId("stream");
    this.effectiveRequest = config.params.abortController
      ? { signal: config.params.abortController.signal }
      : void 0;
    this.effectiveSignal = config.params.abortController?.signal;

    const allMiddleware = [devtoolsMiddleware(), ...(config.middleware || [])];
    this.middlewareRunner = new MiddlewareRunner(allMiddleware);
    // Separate controller so middleware can abort the run independently of
    // the caller's abort signal.
    this.middlewareAbortController = new AbortController();
    this.middlewareCtx = {
      requestId: this.requestId,
      streamId: this.streamId,
      conversationId: config.params.conversationId,
      phase: "init",
      iteration: 0,
      chunkIndex: 0,
      signal: this.effectiveSignal,
      abort: (reason) => {
        this.abortReason = reason;
        this.middlewareAbortController?.abort(reason);
      },
      context: config.context,
      defer: (promise) => {
        this.deferredPromises.push(promise);
      },
      // Provider / adapter info
      provider: config.adapter.name,
      model: config.params.model,
      source: "server",
      streaming: true,
      // Config-derived (updated in beforeRun and applyMiddlewareConfig)
      systemPrompts: this.systemPrompts,
      toolNames: void 0,
      options: void 0,
      modelOptions: config.params.modelOptions,
      // Computed
      messageCount: this.initialMessageCount,
      hasTools: this.tools.length > 0,
      // Mutable per-iteration
      currentMessageId: null,
      accumulatedContent: "",
      // References
      messages: this.messages,
      createId: (prefix) => this.createId(prefix)
    };
  }

  /** Get the accumulated content after the chat loop completes */
  getAccumulatedContent() {
    return this.accumulatedContent;
  }

  /** Get the final messages array after the chat loop completes */
  getMessages() {
    return this.messages;
  }

  /**
   * Run the chat loop as an async generator of stream chunks.
   *
   * The order of operations is:
   *   1. Apply the middleware config hook and run the start hook.
   *   2. Resolve tool calls left pending in the incoming messages.
   *   3. Alternate model streaming and tool execution until
   *      `shouldContinue()` is false.
   *
   * At most one terminal middleware hook (onFinish / onAbort / onError) runs,
   * guarded by `terminalHookCalled`. A cancelled run always reports through
   * onAbort, never onFinish. A MiddlewareAbortError is reported via onAbort
   * and swallowed; any other error is reported via onError and rethrown.
   * Promises registered through `ctx.defer` are awaited (allSettled) before
   * the generator completes.
   */
  async *run() {
    this.beforeRun();
    try {
      this.middlewareCtx.phase = "init";
      const initialConfig = this.buildMiddlewareConfig();
      const transformedConfig = await this.middlewareRunner.runOnConfig(
        this.middlewareCtx,
        initialConfig
      );
      this.applyMiddlewareConfig(transformedConfig);
      await this.middlewareRunner.runOnStart(this.middlewareCtx);

      const pendingPhase = yield* this.checkForPendingToolCalls();
      if (pendingPhase === "wait") {
        return;
      }

      do {
        // NOTE(review): after a RUN_ERROR chunk (earlyTermination) this returns
        // without any terminal hook running (not cancelled, so `finally` skips
        // onAbort too); confirm that is intended.
        if (this.earlyTermination || this.isCancelled()) {
          return;
        }
        await this.beginCycle();
        if (this.cyclePhase === "processText") {
          this.middlewareCtx.phase = "beforeModel";
          this.middlewareCtx.iteration = this.iterationCount;
          const iterConfig = this.buildMiddlewareConfig();
          const iterTransformedConfig = await this.middlewareRunner.runOnConfig(
            this.middlewareCtx,
            iterConfig
          );
          this.applyMiddlewareConfig(iterTransformedConfig);
          yield* this.streamModelResponse();
        } else {
          yield* this.processToolCalls();
        }
        this.endCycle();
      } while (this.shouldContinue());

      // A run cancelled during the last tool phase (e.g. middleware abort()
      // makes processToolCalls set toolPhase "stop") also exits the loop here;
      // leave it to the `finally` block so it is reported via onAbort.
      if (
        !this.terminalHookCalled &&
        this.toolPhase !== "wait" &&
        !this.isCancelled()
      ) {
        this.terminalHookCalled = true;
        await this.middlewareRunner.runOnFinish(this.middlewareCtx, {
          finishReason: this.lastFinishReason,
          duration: Date.now() - this.streamStartTime,
          content: this.accumulatedContent,
          usage: this.finishedEvent?.usage
        });
      }
    } catch (error) {
      if (!this.terminalHookCalled) {
        this.terminalHookCalled = true;
        if (error instanceof MiddlewareAbortError) {
          this.abortReason = error.message;
          await this.middlewareRunner.runOnAbort(this.middlewareCtx, {
            reason: error.message,
            duration: Date.now() - this.streamStartTime
          });
        } else {
          await this.middlewareRunner.runOnError(this.middlewareCtx, {
            error,
            duration: Date.now() - this.streamStartTime
          });
        }
      }
      if (!(error instanceof MiddlewareAbortError)) {
        throw error;
      }
    } finally {
      if (!this.terminalHookCalled && this.isCancelled()) {
        this.terminalHookCalled = true;
        await this.middlewareRunner.runOnAbort(this.middlewareCtx, {
          reason: this.abortReason,
          duration: Date.now() - this.streamStartTime
        });
      }
      if (this.deferredPromises.length > 0) {
        await Promise.allSettled(this.deferredPromises);
      }
    }
  }

  /**
   * Record the start time and expose the initial sampling options and tool
   * names on the middleware context. Only options that were explicitly set
   * are included; `options` is undefined when none are.
   */
  beforeRun() {
    this.streamStartTime = Date.now();
    const { tools, temperature, topP, maxTokens, metadata } = this.params;
    const options = {};
    if (temperature !== void 0) options.temperature = temperature;
    if (topP !== void 0) options.topP = topP;
    if (maxTokens !== void 0) options.maxTokens = maxTokens;
    if (metadata !== void 0) options.metadata = metadata;
    this.eventOptions = Object.keys(options).length > 0 ? options : void 0;
    this.eventToolNames = tools?.map((t) => t.name);
    this.middlewareCtx.options = this.eventOptions;
    this.middlewareCtx.toolNames = this.eventToolNames;
  }

  /** Start a cycle; only the text step begins a new iteration. */
  async beginCycle() {
    if (this.cyclePhase === "processText") {
      await this.beginIteration();
    }
  }

  /**
   * Advance the two-step cycle: text -> tool calls, then tool calls -> text
   * with the iteration counter incremented.
   */
  endCycle() {
    if (this.cyclePhase === "processText") {
      this.cyclePhase = "executeToolCalls";
      return;
    }
    this.cyclePhase = "processText";
    this.iterationCount++;
  }

  /** Reset per-iteration state and run the middleware iteration hook. */
  async beginIteration() {
    this.currentMessageId = this.createId("msg");
    this.accumulatedContent = "";
    this.finishedEvent = null;
    this.middlewareCtx.currentMessageId = this.currentMessageId;
    this.middlewareCtx.accumulatedContent = "";
    await this.middlewareRunner.runOnIteration(this.middlewareCtx, {
      iteration: this.iterationCount,
      messageId: this.currentMessageId
    });
  }

  /**
   * Stream one model response from the adapter.
   *
   * Each raw chunk is passed through the middleware chunk hook; every chunk
   * that hook returns is yielded and then applied to engine state. Stops early
   * on cancellation or after a RUN_ERROR chunk.
   */
  async *streamModelResponse() {
    const { temperature, topP, maxTokens, metadata, modelOptions } = this.params;
    // Adapters receive JSON Schema rather than the tool's original schema objects.
    const toolsWithJsonSchemas = this.tools.map((tool) => ({
      ...tool,
      inputSchema: tool.inputSchema
        ? convertSchemaToJsonSchema(tool.inputSchema)
        : void 0,
      outputSchema: tool.outputSchema
        ? convertSchemaToJsonSchema(tool.outputSchema)
        : void 0
    }));
    this.middlewareCtx.phase = "modelStream";
    for await (const chunk of this.adapter.chatStream({
      model: this.params.model,
      messages: this.messages,
      tools: toolsWithJsonSchemas,
      temperature,
      topP,
      maxTokens,
      metadata,
      request: this.effectiveRequest,
      modelOptions,
      systemPrompts: this.systemPrompts
    })) {
      if (this.isCancelled()) {
        break;
      }
      this.totalChunkCount++;
      const outputChunks = await this.middlewareRunner.runOnChunk(
        this.middlewareCtx,
        chunk
      );
      for (const outputChunk of outputChunks) {
        yield outputChunk;
        this.handleStreamChunk(outputChunk);
        this.middlewareCtx.chunkIndex++;
      }
      // Usage is read from the raw adapter chunk, not the middleware output.
      if (chunk.type === "RUN_FINISHED" && chunk.usage) {
        await this.middlewareRunner.runOnUsage(this.middlewareCtx, chunk.usage);
      }
      if (this.earlyTermination) {
        break;
      }
    }
  }

  /** Dispatch a (post-middleware) stream chunk to its AG-UI event handler. */
  handleStreamChunk(chunk) {
    switch (chunk.type) {
      // AG-UI Events
      case "TEXT_MESSAGE_CONTENT":
        this.handleTextMessageContentEvent(chunk);
        break;
      case "TOOL_CALL_START":
        this.handleToolCallStartEvent(chunk);
        break;
      case "TOOL_CALL_ARGS":
        this.handleToolCallArgsEvent(chunk);
        break;
      case "TOOL_CALL_END":
        this.handleToolCallEndEvent(chunk);
        break;
      case "RUN_FINISHED":
        this.handleRunFinishedEvent(chunk);
        break;
      case "RUN_ERROR":
        this.handleRunErrorEvent(chunk);
        break;
      case "STEP_FINISHED":
        this.handleStepFinishedEvent(chunk);
        break;
    }
  }

  // ===========================
  // AG-UI Event Handlers
  // ===========================

  /**
   * Track the assistant text. A chunk carrying `content` replaces the
   * accumulated text; otherwise its `delta` is appended.
   */
  handleTextMessageContentEvent(chunk) {
    if (chunk.content) {
      this.accumulatedContent = chunk.content;
    } else {
      this.accumulatedContent += chunk.delta;
    }
    this.middlewareCtx.accumulatedContent = this.accumulatedContent;
  }

  handleToolCallStartEvent(chunk) {
    this.toolCallManager.addToolCallStartEvent(chunk);
  }

  handleToolCallArgsEvent(chunk) {
    this.toolCallManager.addToolCallArgsEvent(chunk);
  }

  handleToolCallEndEvent(chunk) {
    this.toolCallManager.completeToolCall(chunk);
  }

  handleRunFinishedEvent(chunk) {
    this.finishedEvent = chunk;
    this.lastFinishReason = chunk.finishReason;
  }

  /** A RUN_ERROR stops streaming and ends the loop at the next cycle start. */
  handleRunErrorEvent(_chunk) {
    this.earlyTermination = true;
  }

  /** No-op: STEP_FINISHED carries no state this engine tracks. */
  handleStepFinishedEvent(_chunk) {}

  /**
   * Execute tool calls that the incoming messages left without a result.
   *
   * This covers calls that were waiting for approval or client execution.
   * @returns {Promise<"continue" | "wait" | "stop">} Generator return value.
   *   - "wait": approvals or client execution are still needed.
   *   - "stop": middleware aborted the run.
   *   - "continue": otherwise.
   */
  async *checkForPendingToolCalls() {
    const pendingToolCalls = this.getPendingToolCallsFromMessages();
    if (pendingToolCalls.length === 0) {
      return "continue";
    }
    const finishEvent = this.createSyntheticFinishedEvent();
    const { executable, undiscoveredResults } =
      this.partitionUndiscoveredLazyToolCalls(pendingToolCalls);
    if (undiscoveredResults.length > 0) {
      yield* this.buildToolResultChunks(undiscoveredResults, finishEvent);
    }
    if (executable.length === 0) {
      return "continue";
    }

    // NOTE(review): unlike processToolCalls, middlewareCtx.phase is not moved
    // to "beforeTools"/"afterTools" here (it is still "init"); confirm intended.
    const executionResult = yield* this.drainToolCallGenerator(
      this.startToolExecution(executable)
    );
    if (this.isMiddlewareAborted()) {
      this.setToolPhase("stop");
      return "stop";
    }
    await this.middlewareRunner.runOnToolPhaseComplete(this.middlewareCtx, {
      toolCalls: pendingToolCalls,
      results: executionResult.results,
      needsApproval: executionResult.needsApproval,
      needsClientExecution: executionResult.needsClientExecution
    });

    // Pending calls come from the incoming messages, so this run never streamed
    // their TOOL_CALL_START/ARGS events. Passing argsMap makes
    // buildToolResultChunks emit those events before each result.
    const argsMap = new Map(
      pendingToolCalls.map((tc) => [tc.id, tc.function.arguments])
    );
    if (this.needsClientAction(executionResult)) {
      yield* this.awaitingClientChunks(executionResult, finishEvent, argsMap);
      this.setToolPhase("wait");
      return "wait";
    }
    yield* this.buildToolResultChunks(executionResult.results, finishEvent, argsMap);
    return "continue";
  }

  /**
   * Execute the tool calls collected from the latest model response.
   *
   * The outcome is recorded in `toolPhase`:
   *   - "continue": loop again.
   *   - "wait": approvals or client execution are needed.
   *   - "stop": there was nothing to run, or middleware aborted.
   */
  async *processToolCalls() {
    if (!this.shouldExecuteToolPhase()) {
      this.setToolPhase("stop");
      return;
    }
    const toolCalls = this.toolCallManager.getToolCalls();
    const finishEvent = this.finishedEvent;
    if (!finishEvent || toolCalls.length === 0) {
      this.setToolPhase("stop");
      return;
    }
    this.addAssistantToolCallMessage(toolCalls);

    const { executable, undiscoveredResults } =
      this.partitionUndiscoveredLazyToolCalls(toolCalls);
    if (undiscoveredResults.length > 0) {
      yield* this.buildToolResultChunks(undiscoveredResults, finishEvent);
    }
    if (executable.length === 0) {
      this.toolCallManager.clear();
      this.setToolPhase("continue");
      return;
    }

    this.middlewareCtx.phase = "beforeTools";
    const executionResult = yield* this.drainToolCallGenerator(
      this.startToolExecution(executable)
    );
    this.middlewareCtx.phase = "afterTools";
    if (this.isMiddlewareAborted()) {
      this.setToolPhase("stop");
      return;
    }
    await this.middlewareRunner.runOnToolPhaseComplete(this.middlewareCtx, {
      toolCalls,
      results: executionResult.results,
      needsApproval: executionResult.needsApproval,
      needsClientExecution: executionResult.needsClientExecution
    });

    if (this.needsClientAction(executionResult)) {
      yield* this.awaitingClientChunks(executionResult, finishEvent);
      this.setToolPhase("wait");
      return;
    }
    yield* this.buildToolResultChunks(executionResult.results, finishEvent);

    if (this.lazyToolManager.hasNewlyDiscoveredTools()) {
      // Newly discovered lazy tools become active; rebuild the manager for them.
      this.tools = this.lazyToolManager.getActiveTools();
      this.toolCallManager = new ToolCallManager(this.tools);
    } else {
      this.toolCallManager.clear();
    }
    this.setToolPhase("continue");
  }

  /**
   * Split tool calls into those that can run now and synthetic
   * "output-error" results for calls that target lazy tools the model has
   * not discovered yet. Input order is preserved in both lists.
   */
  partitionUndiscoveredLazyToolCalls(toolCalls) {
    const executable = [];
    const undiscoveredResults = [];
    for (const tc of toolCalls) {
      const toolName = tc.function.name;
      if (this.lazyToolManager.isUndiscoveredLazyTool(toolName)) {
        undiscoveredResults.push({
          toolCallId: tc.id,
          toolName,
          result: {
            error: this.lazyToolManager.getUndiscoveredToolError(toolName)
          },
          state: "output-error"
        });
      } else {
        executable.push(tc);
      }
    }
    return { executable, undiscoveredResults };
  }

  /**
   * Start executing `toolCalls` via executeToolCalls.
   *
   * The call is wired with the current approvals and client tool results,
   * with the middleware before/after tool-call hooks, and with a factory for
   * CUSTOM event chunks. Returns the executeToolCalls generator; drain it
   * with drainToolCallGenerator.
   */
  startToolExecution(toolCalls) {
    const { approvals, clientToolResults } = this.collectClientState();
    return executeToolCalls(
      toolCalls,
      this.tools,
      approvals,
      clientToolResults,
      (eventName, data) => this.createCustomEventChunk(eventName, data),
      {
        onBeforeToolCall: async (toolCall, tool, args) => {
          const hookCtx = {
            toolCall,
            tool,
            args,
            toolName: toolCall.function.name,
            toolCallId: toolCall.id
          };
          return this.middlewareRunner.runOnBeforeToolCall(
            this.middlewareCtx,
            hookCtx
          );
        },
        onAfterToolCall: async (info) => {
          await this.middlewareRunner.runOnAfterToolCall(
            this.middlewareCtx,
            info
          );
        }
      }
    );
  }

  /** True when the tool phase must pause for approvals or client execution. */
  needsClientAction(executionResult) {
    return (
      executionResult.needsApproval.length > 0 ||
      executionResult.needsClientExecution.length > 0
    );
  }

  /**
   * Yield the chunks sent before pausing for the client, in this order:
   * results of calls that already completed, then approval requests, then
   * client-tool input requests. Each group is built only once the previous
   * group has been consumed.
   */
  *awaitingClientChunks(executionResult, finishEvent, argsMap) {
    if (executionResult.results.length > 0) {
      yield* this.buildToolResultChunks(
        executionResult.results,
        finishEvent,
        argsMap
      );
    }
    yield* this.buildApprovalChunks(executionResult.needsApproval, finishEvent);
    yield* this.buildClientToolChunks(
      executionResult.needsClientExecution,
      finishEvent
    );
  }

  /**
   * The tool phase runs only when the model finished with "tool_calls",
   * tools are available, and at least one tool call was collected.
   */
  shouldExecuteToolPhase() {
    return (
      this.finishedEvent?.finishReason === "tool_calls" &&
      this.tools.length > 0 &&
      this.toolCallManager.hasToolCalls()
    );
  }

  /** Append the assistant message (text, or null, plus tool calls) to the history. */
  addAssistantToolCallMessage(toolCalls) {
    this.messages = [
      ...this.messages,
      {
        role: "assistant",
        content: this.accumulatedContent || null,
        toolCalls
      }
    ];
  }

  /**
   * Extract client state (approvals and client tool results) from original messages.
   * This is called in the constructor BEFORE converting to ModelMessage format,
   * because the parts array (which contains approval state) is lost during conversion.
   *
   * @returns {{approvals: Map, clientToolResults: Map}}
   *   - approvals: approval id -> approved flag, for parts in the
   *     "approval-responded" state.
   *   - clientToolResults: tool-call part id -> output, for parts with output
   *     and no approval.
   */
  extractClientStateFromOriginalMessages(originalMessages) {
    const approvals = new Map();
    const clientToolResults = new Map();
    for (const message of originalMessages) {
      if (message.role !== "assistant" || !message.parts) {
        continue;
      }
      for (const part of message.parts) {
        if (part.type !== "tool-call") {
          continue;
        }
        if (part.output !== void 0 && !part.approval) {
          clientToolResults.set(part.id, part.output);
        }
        if (
          part.approval?.id &&
          part.approval?.approved !== void 0 &&
          part.state === "approval-responded"
        ) {
          approvals.set(part.approval.id, part.approval.approved);
        }
      }
    }
    return { approvals, clientToolResults };
  }

  /**
   * Combine the initial client state with results from the current tool
   * messages.
   *
   * Tool message content is JSON-parsed when possible, and the raw content
   * is used otherwise. Placeholders marked `pendingExecution: true` are skipped.
   */
  collectClientState() {
    const approvals = new Map(this.initialApprovals);
    const clientToolResults = new Map(this.initialClientToolResults);
    for (const message of this.messages) {
      if (message.role !== "tool" || !message.toolCallId) {
        continue;
      }
      let output;
      try {
        output = JSON.parse(message.content);
      } catch {
        // Non-JSON tool output is used verbatim.
        output = message.content;
      }
      if (output && typeof output === "object" && output.pendingExecution === true) {
        continue;
      }
      clientToolResults.set(message.toolCallId, output);
    }
    return { approvals, clientToolResults };
  }

  /** Build one CUSTOM "approval-requested" chunk per pending approval. */
  buildApprovalChunks(approvals, finishEvent) {
    const chunks = [];
    for (const approval of approvals) {
      chunks.push({
        type: "CUSTOM",
        timestamp: Date.now(),
        model: finishEvent.model,
        name: "approval-requested",
        value: {
          toolCallId: approval.toolCallId,
          toolName: approval.toolName,
          input: approval.input,
          approval: {
            id: approval.approvalId,
            needsApproval: true
          }
        }
      });
    }
    return chunks;
  }

  /** Build one CUSTOM "tool-input-available" chunk per client-executed tool. */
  buildClientToolChunks(clientRequests, finishEvent) {
    const chunks = [];
    for (const clientTool of clientRequests) {
      chunks.push({
        type: "CUSTOM",
        timestamp: Date.now(),
        model: finishEvent.model,
        name: "tool-input-available",
        value: {
          toolCallId: clientTool.toolCallId,
          toolName: clientTool.toolName,
          input: clientTool.input
        }
      });
    }
    return chunks;
  }

  /**
   * Build TOOL_CALL_END chunks for `results` and append a "tool" message per
   * result to `this.messages` (as a side effect).
   *
   * When `argsMap` (toolCallId -> JSON args string) is given, a
   * TOOL_CALL_START and a TOOL_CALL_ARGS chunk are emitted before each END.
   * Missing args default to "{}".
   */
  buildToolResultChunks(results, finishEvent, argsMap) {
    const chunks = [];
    for (const result of results) {
      const content = JSON.stringify(result.result);
      if (argsMap) {
        chunks.push({
          type: "TOOL_CALL_START",
          timestamp: Date.now(),
          model: finishEvent.model,
          toolCallId: result.toolCallId,
          toolName: result.toolName
        });
        const args = argsMap.get(result.toolCallId) ?? "{}";
        chunks.push({
          type: "TOOL_CALL_ARGS",
          timestamp: Date.now(),
          model: finishEvent.model,
          toolCallId: result.toolCallId,
          delta: args,
          args
        });
      }
      chunks.push({
        type: "TOOL_CALL_END",
        timestamp: Date.now(),
        model: finishEvent.model,
        toolCallId: result.toolCallId,
        toolName: result.toolName,
        result: content
      });
      this.messages = [
        ...this.messages,
        {
          role: "tool",
          content,
          toolCallId: result.toolCallId
        }
      ];
    }
    return chunks;
  }

  /**
   * Return assistant tool calls that have no completed tool message.
   *
   * A tool message whose JSON content has `pendingExecution: true` counts
   * as not completed.
   */
  getPendingToolCallsFromMessages() {
    const completedToolIds = new Set();
    for (const message of this.messages) {
      if (message.role !== "tool" || !message.toolCallId) {
        continue;
      }
      let hasPendingExecution = false;
      if (typeof message.content === "string") {
        try {
          // `?.` covers content that parses to null.
          hasPendingExecution = JSON.parse(message.content)?.pendingExecution === true;
        } catch {
          // Non-JSON content is a normal, completed tool result.
        }
      }
      if (!hasPendingExecution) {
        completedToolIds.add(message.toolCallId);
      }
    }
    const pending = [];
    for (const message of this.messages) {
      if (message.role === "assistant" && message.toolCalls) {
        for (const toolCall of message.toolCalls) {
          if (!completedToolIds.has(toolCall.id)) {
            pending.push(toolCall);
          }
        }
      }
    }
    return pending;
  }

  /**
   * RUN_FINISHED stand-in used as the chunk template (model field) when
   * resuming pending tool calls, since no model response exists yet.
   */
  createSyntheticFinishedEvent() {
    return {
      type: "RUN_FINISHED",
      runId: this.createId("pending"),
      model: this.params.model,
      timestamp: Date.now(),
      finishReason: "tool_calls"
    };
  }

  /**
   * After the text step always proceed to the tool step. After the tool
   * step, continue only if the loop strategy allows it and the tool phase
   * is "continue".
   */
  shouldContinue() {
    if (this.cyclePhase === "executeToolCalls") {
      return true;
    }
    return (
      this.loopStrategy({
        iterationCount: this.iterationCount,
        messages: this.messages,
        finishReason: this.lastFinishReason
      }) && this.toolPhase === "continue"
    );
  }

  /** True when the caller's abort signal has fired. */
  isAborted() {
    return !!this.effectiveSignal?.aborted;
  }

  /** True when middleware called `ctx.abort()`. */
  isMiddlewareAborted() {
    return !!this.middlewareAbortController?.signal.aborted;
  }

  isCancelled() {
    return this.isAborted() || this.isMiddlewareAborted();
  }

  /**
   * Snapshot of the config middleware may transform. `systemPrompts` and
   * `tools` are shallow copies; `messages` is passed by reference.
   */
  buildMiddlewareConfig() {
    return {
      messages: this.messages,
      systemPrompts: [...this.systemPrompts],
      tools: [...this.tools],
      temperature: this.params.temperature,
      topP: this.params.topP,
      maxTokens: this.params.maxTokens,
      metadata: this.params.metadata,
      modelOptions: this.params.modelOptions
    };
  }

  /** Adopt a middleware-transformed config and mirror it onto the middleware context. */
  applyMiddlewareConfig(config) {
    this.messages = config.messages;
    this.systemPrompts = config.systemPrompts;
    // NOTE(review): toolCallManager is not rebuilt when middleware changes
    // `tools`; confirm ToolCallManager does not need the updated list.
    this.tools = config.tools;
    this.params = {
      ...this.params,
      temperature: config.temperature,
      topP: config.topP,
      maxTokens: config.maxTokens,
      metadata: config.metadata,
      modelOptions: config.modelOptions
    };
    this.middlewareCtx.messages = this.messages;
    this.middlewareCtx.systemPrompts = this.systemPrompts;
    this.middlewareCtx.hasTools = this.tools.length > 0;
    this.middlewareCtx.toolNames = this.tools.map((t) => t.name);
    this.middlewareCtx.modelOptions = config.modelOptions;
  }

  setToolPhase(phase) {
    this.toolPhase = phase;
  }

  /**
   * Drain an executeToolCalls async generator, yielding any CustomEvent chunks
   * and returning the final ExecuteToolCallsResult.
   */
  async *drainToolCallGenerator(generator) {
    let next = await generator.next();
    while (!next.done) {
      yield next.value;
      next = await generator.next();
    }
    return next.value;
  }

  /** Build a CUSTOM chunk tagged with the current model. */
  createCustomEventChunk(eventName, value) {
    return {
      type: "CUSTOM",
      timestamp: Date.now(),
      model: this.params.model,
      name: eventName,
      value
    };
  }

  /** `${prefix}-${epochMillis}-${7 random base36 chars}`; not cryptographically random. */
  createId(prefix) {
    return `${prefix}-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`;
  }
}

/**
 * Entry point for text activities. The mode is chosen from `options`:
 *   - `outputSchema` set: runs the agent loop to completion, then requests
 *     structured output from the adapter. Returns a Promise of the
 *     (validated, for Standard Schema) data.
 *   - `stream === false`: returns the result of `streamToText` over the
 *     chunk stream.
 *   - otherwise: returns an async iterable of stream chunks.
 */
function chat(options) {
  const { outputSchema, stream } = options;
  if (outputSchema) {
    return runAgenticStructuredOutput(options);
  }
  if (stream === false) {
    return runNonStreamingText(options);
  }
  return runStreamingText(options);
}

/** Stream chunks from a TextEngine built from `options` (model taken from the adapter). */
async function* runStreamingText(options) {
  const { adapter, middleware, context, ...textOptions } = options;
  const model = adapter.model;
  const engine = new TextEngine({
    adapter,
    params: { ...textOptions, model },
    middleware,
    context
  });
  for await (const chunk of engine.run()) {
    yield chunk;
  }
}

/** Non-streaming variant: collapses the chunk stream with `streamToText`. */
function runNonStreamingText(options) {
  const stream = runStreamingText(options);
  return streamToText(stream);
}

/**
 * Run the full agent loop, then ask the adapter for structured output over
 * the final conversation.
 *
 * Tools and the loop strategy are stripped from the structured-output
 * request. Results are validated with `parseWithStandardSchema` when
 * `outputSchema` is a Standard Schema.
 * @throws {Error} If `outputSchema` is missing or cannot be converted to JSON Schema.
 */
async function runAgenticStructuredOutput(options) {
  const { adapter, outputSchema, middleware, context, ...textOptions } = options;
  const model = adapter.model;
  if (!outputSchema) {
    throw new Error("outputSchema is required for structured output");
  }
  const engine = new TextEngine({
    adapter,
    params: { ...textOptions, model },
    middleware,
    context
  });
  // Drain the loop only for its side effects (tool execution, message history).
  for await (const _chunk of engine.run()) {
  }
  const finalMessages = engine.getMessages();
  const {
    tools: _tools,
    agentLoopStrategy: _als,
    ...structuredTextOptions
  } = textOptions;
  const jsonSchema = convertSchemaToJsonSchema(outputSchema);
  if (!jsonSchema) {
    throw new Error("Failed to convert output schema to JSON Schema");
  }
  const result = await adapter.structuredOutput({
    chatOptions: {
      ...structuredTextOptions,
      model,
      messages: finalMessages
    },
    outputSchema: jsonSchema
  });
  if (isStandardSchema(outputSchema)) {
    return parseWithStandardSchema(outputSchema, result.data);
  }
  return result.data;
}

export { chat, createChatOptions, kind };
//# sourceMappingURL=index.js.map