UNPKG

@tanstack/ai

Version:

Core TanStack AI library - Open source AI SDK

736 lines (671 loc) 21.1 kB
import { isStandardSchema, parseWithStandardSchema } from './schema-converter'
import type {
  CustomEvent,
  ModelMessage,
  RunFinishedEvent,
  Tool,
  ToolCall,
  ToolCallArgsEvent,
  ToolCallEndEvent,
  ToolCallStartEvent,
  ToolExecutionContext,
} from '../../../types'
import type {
  AfterToolCallInfo,
  BeforeToolCallDecision,
} from '../middleware/types'

/**
 * How often (in milliseconds) a running tool is polled for queued custom
 * events while it executes.
 */
const EVENT_POLL_INTERVAL_MS = 10

/**
 * Parse `value` as JSON, returning the original string unchanged when it is
 * not valid JSON.
 */
function safeJsonParse(value: string): unknown {
  try {
    return JSON.parse(value)
  } catch {
    return value
  }
}

/**
 * Parse the accumulated JSON arguments of a tool call.
 *
 * Blank input and the literal `null` (sent for empty tool_use blocks) are both
 * normalized to `{}` so tools receive an empty object rather than `null`.
 * Shared by `ToolCallManager.executeTools` and `executeToolCalls` so both
 * execution paths treat arguments identically.
 *
 * @param rawArguments - The tool call's `function.arguments` string
 * @returns The parsed arguments
 * @throws Error if the (trimmed) arguments are not valid JSON
 */
function parseToolArguments(rawArguments: string): unknown {
  const trimmed = rawArguments.trim()
  const normalized = trimmed === '' || trimmed === 'null' ? '{}' : trimmed
  try {
    return JSON.parse(normalized)
  } catch {
    throw new Error(`Failed to parse tool arguments as JSON: ${normalized}`)
  }
}

/**
 * Yield and remove every event currently queued in `pendingEvents`, oldest
 * first. Events queued while the consumer handles a yielded event are drained
 * as well.
 */
function* drainPendingEvents(
  pendingEvents: Array<CustomEvent>,
): Generator<CustomEvent, void, unknown> {
  let event = pendingEvents.shift()
  while (event !== undefined) {
    yield event
    event = pendingEvents.shift()
  }
}

/**
 * Optional middleware hooks for tool execution.
 * When provided, these callbacks are invoked before/after each tool execution.
 */
export interface ToolExecutionMiddlewareHooks {
  onBeforeToolCall?: (
    toolCall: ToolCall,
    tool: Tool | undefined,
    args: unknown,
  ) => Promise<BeforeToolCallDecision>
  onAfterToolCall?: (info: AfterToolCallInfo) => Promise<void>
}

/**
 * Error thrown when middleware decides to abort the chat run during tool execution.
 */
export class MiddlewareAbortError extends Error {
  constructor(reason: string) {
    super(reason)
    this.name = 'MiddlewareAbortError'
  }
}

/**
 * Manages tool call accumulation and execution for the chat() method's automatic tool execution loop.
 *
 * Responsibilities:
 * - Accumulates streaming tool call events (ID, name, arguments)
 * - Validates tool calls (filters out incomplete ones)
 * - Executes tool `execute` functions with parsed arguments
 * - Emits `TOOL_CALL_END` events for client visibility
 * - Returns tool result messages for conversation history
 *
 * This class is used internally by the AI.chat() method to handle the automatic
 * tool execution loop. It can also be used independently for custom tool execution logic.
 *
 * @example
 * ```typescript
 * const manager = new ToolCallManager(tools);
 *
 * // During streaming, accumulate tool calls
 * for await (const chunk of stream) {
 *   if (chunk.type === 'TOOL_CALL_START') {
 *     manager.addToolCallStartEvent(chunk);
 *   } else if (chunk.type === 'TOOL_CALL_ARGS') {
 *     manager.addToolCallArgsEvent(chunk);
 *   }
 * }
 *
 * // After stream completes, execute tools
 * if (manager.hasToolCalls()) {
 *   const toolResults = yield* manager.executeTools(finishEvent);
 *   messages = [...messages, ...toolResults];
 *   manager.clear();
 * }
 * ```
 */
export class ToolCallManager {
  private toolCallsMap = new Map<number, ToolCall>()
  private tools: ReadonlyArray<Tool>

  constructor(tools: ReadonlyArray<Tool>) {
    this.tools = tools
  }

  /**
   * Add a TOOL_CALL_START event to begin tracking a tool call (AG-UI).
   * Calls are keyed by the event's `index`. When the index is absent, the
   * current number of tracked calls is used as the key.
   */
  addToolCallStartEvent(event: ToolCallStartEvent): void {
    const index = event.index ?? this.toolCallsMap.size
    this.toolCallsMap.set(index, {
      id: event.toolCallId,
      type: 'function',
      function: {
        name: event.toolName,
        arguments: '',
      },
      ...(event.providerMetadata && {
        providerMetadata: event.providerMetadata,
      }),
    })
  }

  /**
   * Add a TOOL_CALL_ARGS event to accumulate arguments (AG-UI).
   * Deltas for unknown tool call IDs are ignored.
   */
  addToolCallArgsEvent(event: ToolCallArgsEvent): void {
    const toolCall = this.findToolCall(event.toolCallId)
    if (toolCall) {
      toolCall.function.arguments += event.delta
    }
  }

  /**
   * Complete a tool call with its final input.
   * Called when TOOL_CALL_END is received. When the event carries an `input`,
   * it replaces the streamed argument string with its JSON serialization.
   */
  completeToolCall(event: ToolCallEndEvent): void {
    const toolCall = this.findToolCall(event.toolCallId)
    if (toolCall && event.input !== undefined) {
      toolCall.function.arguments = JSON.stringify(event.input)
    }
  }

  /**
   * Check if there are any complete tool calls to execute
   */
  hasToolCalls(): boolean {
    return this.getToolCalls().length > 0
  }

  /**
   * Get all complete tool calls (filtered for valid ID and name)
   */
  getToolCalls(): Array<ToolCall> {
    return Array.from(this.toolCallsMap.values()).filter(
      (tc) => tc.id && tc.function.name && tc.function.name.trim().length > 0,
    )
  }

  /**
   * Execute all tool calls and return tool result messages.
   * Yields one TOOL_CALL_END event per tool call for streaming.
   *
   * Failures (bad JSON, schema violations, thrown errors) never propagate:
   * they become an `Error executing tool: ...` result string instead.
   *
   * @param finishEvent - RUN_FINISHED event from the stream
   * @returns One `tool` role message per executed tool call
   */
  async *executeTools(
    finishEvent: RunFinishedEvent,
  ): AsyncGenerator<ToolCallEndEvent, Array<ModelMessage>, void> {
    const toolCallsArray = this.getToolCalls()
    const toolResults: Array<ModelMessage> = []

    for (const toolCall of toolCallsArray) {
      const tool = this.tools.find((t) => t.name === toolCall.function.name)

      let toolResultContent: string
      if (tool?.execute) {
        try {
          let args: unknown = parseToolArguments(toolCall.function.arguments)

          // Validate input against inputSchema (for Standard Schema compliant schemas)
          if (tool.inputSchema && isStandardSchema(tool.inputSchema)) {
            try {
              args = parseWithStandardSchema(tool.inputSchema, args)
            } catch (validationError: unknown) {
              const message =
                validationError instanceof Error
                  ? validationError.message
                  : 'Validation failed'
              throw new Error(
                `Input validation failed for tool ${tool.name}: ${message}`,
              )
            }
          }

          // Provide the same context shape executeToolCalls passes, so tools
          // that use it do not crash on this path. This path only yields
          // TOOL_CALL_END events, so custom events emitted here are discarded.
          const context: ToolExecutionContext = {
            toolCallId: toolCall.id,
            emitCustomEvent: () => undefined,
          }

          // Execute the tool
          let result = await tool.execute(args, context)

          // Validate output against outputSchema if provided (for Standard Schema compliant schemas)
          if (
            tool.outputSchema &&
            isStandardSchema(tool.outputSchema) &&
            result !== undefined &&
            result !== null
          ) {
            try {
              result = parseWithStandardSchema(tool.outputSchema, result)
            } catch (validationError: unknown) {
              const message =
                validationError instanceof Error
                  ? validationError.message
                  : 'Validation failed'
              throw new Error(
                `Output validation failed for tool ${tool.name}: ${message}`,
              )
            }
          }

          if (typeof result === 'string') {
            toolResultContent = result
          } else {
            // JSON.stringify returns undefined for `undefined` (e.g. void
            // tools), functions and symbols; fall back to 'null' so the
            // message content is always a string.
            const serialized: string | undefined = JSON.stringify(result)
            toolResultContent = serialized ?? 'null'
          }
        } catch (error: unknown) {
          // If tool execution fails, add error message
          const message =
            error instanceof Error ? error.message : 'Unknown error'
          toolResultContent = `Error executing tool: ${message}`
        }
      } else {
        // Tool doesn't have execute function, add placeholder
        toolResultContent = `Tool ${toolCall.function.name} does not have an execute function`
      }

      // Emit TOOL_CALL_END event
      yield {
        type: 'TOOL_CALL_END',
        toolCallId: toolCall.id,
        toolName: toolCall.function.name,
        model: finishEvent.model,
        timestamp: Date.now(),
        result: toolResultContent,
      }

      // Add tool result message
      toolResults.push({
        role: 'tool',
        content: toolResultContent,
        toolCallId: toolCall.id,
      })
    }

    return toolResults
  }

  /**
   * Clear the tool calls map for the next iteration
   */
  clear(): void {
    this.toolCallsMap.clear()
  }

  /**
   * Return the first tracked tool call with the given ID, if any.
   */
  private findToolCall(toolCallId: string): ToolCall | undefined {
    for (const toolCall of this.toolCallsMap.values()) {
      if (toolCall.id === toolCallId) {
        return toolCall
      }
    }
    return undefined
  }
}

export interface ToolResult {
  toolCallId: string
  toolName: string
  result: any
  state?: 'output-available' | 'output-error'
  /** Duration of tool execution in milliseconds (only for server-executed tools) */
  duration?: number
}

export interface ApprovalRequest {
  toolCallId: string
  toolName: string
  input: any
  approvalId: string
}

export interface ClientToolRequest {
  toolCallId: string
  toolName: string
  input: any
}

interface ExecuteToolCallsResult {
  /** Tool results ready to send to LLM */
  results: Array<ToolResult>
  /** Tools that need user approval before execution */
  needsApproval: Array<ApprovalRequest>
  /** Tools that need client-side execution */
  needsClientExecution: Array<ClientToolRequest>
}

/**
 * Helper that runs a tool execution promise while polling for pending custom events.
 * Yields any custom events that are emitted during execution, then returns the
 * execution result. If the execution rejects, the rejection propagates to the
 * caller; events still queued at that moment are left in `pendingEvents`.
 */
async function* executeWithEventPolling<T>(
  executionPromise: Promise<T>,
  pendingEvents: Array<CustomEvent>,
): AsyncGenerator<CustomEvent, T, void> {
  // Use an object to track mutable state across the async boundary
  const state = { done: false, result: undefined as T }
  const executionWithFlag = executionPromise.then((r) => {
    state.done = true
    state.result = r
    return r
  })

  while (!state.done) {
    // Wait for either the execution to complete or a short timeout. The timer
    // is always cleared so no stray timeout outlives this iteration.
    let cancelTick: () => void = () => undefined
    const tick = new Promise<void>((resolve) => {
      const handle = setTimeout(resolve, EVENT_POLL_INTERVAL_MS)
      cancelTick = () => clearTimeout(handle)
    })
    try {
      await Promise.race([executionWithFlag, tick])
    } finally {
      cancelTick()
    }

    // Flush any pending events
    yield* drainPendingEvents(pendingEvents)
  }

  // Final flush in case events were emitted right at completion
  yield* drainPendingEvents(pendingEvents)

  return state.result
}

/**
 * Apply a middleware onBeforeToolCall decision.
 * Returns the (possibly transformed) input if execution should proceed,
 * or `{ proceed: false }` if the tool call was skipped (result already pushed).
 * Throws MiddlewareAbortError if the decision is 'abort'.
 */
async function applyBeforeToolCallDecision(
  toolCall: ToolCall,
  tool: Tool,
  input: unknown,
  toolName: string,
  middlewareHooks: ToolExecutionMiddlewareHooks,
  results: Array<ToolResult>,
): Promise<{ proceed: true; input: unknown } | { proceed: false }> {
  if (!middlewareHooks.onBeforeToolCall) {
    return { proceed: true, input }
  }

  const decision = await middlewareHooks.onBeforeToolCall(toolCall, tool, input)
  if (!decision) {
    return { proceed: true, input }
  }

  if (decision.type === 'abort') {
    throw new MiddlewareAbortError(decision.reason || 'Aborted by middleware')
  }

  if (decision.type === 'skip') {
    const skipResult = decision.result
    results.push({
      toolCallId: toolCall.id,
      toolName,
      // `??` (not `||`) so falsy-but-valid results such as 0 or false survive;
      // only undefined/null are normalized to null.
      result:
        typeof skipResult === 'string'
          ? safeJsonParse(skipResult)
          : (skipResult ?? null),
      duration: 0,
    })

    if (middlewareHooks.onAfterToolCall) {
      await middlewareHooks.onAfterToolCall({
        toolCall,
        tool,
        toolName,
        toolCallId: toolCall.id,
        ok: true,
        duration: 0,
        result: skipResult,
      })
    }
    return { proceed: false }
  }

  return { proceed: true, input: decision.args }
}

/**
 * Execute a server-side tool with event polling, output validation, and middleware hooks.
 * Yields CustomEvent chunks during execution and pushes exactly one result to the results array.
 *
 * Errors from the tool or from output validation become an `output-error`
 * result. A MiddlewareAbortError is rethrown. Errors thrown by
 * `onAfterToolCall` propagate to the caller instead of being recorded as a
 * second result for the same tool call.
 *
 * The caller must ensure `tool.execute` is defined.
 */
async function* executeServerTool(
  toolCall: ToolCall,
  tool: Tool,
  toolName: string,
  input: unknown,
  context: ToolExecutionContext,
  pendingEvents: Array<CustomEvent>,
  results: Array<ToolResult>,
  middlewareHooks?: ToolExecutionMiddlewareHooks,
): AsyncGenerator<CustomEvent, void, void> {
  const startTime = Date.now()
  let duration = 0
  let finalResult: unknown = null

  try {
    const executionPromise = Promise.resolve(tool.execute!(input, context))
    let result = yield* executeWithEventPolling(executionPromise, pendingEvents)
    duration = Date.now() - startTime

    // Flush remaining events
    yield* drainPendingEvents(pendingEvents)

    // Validate output against outputSchema if provided
    if (
      tool.outputSchema &&
      isStandardSchema(tool.outputSchema) &&
      result !== undefined &&
      result !== null
    ) {
      result = parseWithStandardSchema(tool.outputSchema, result)
    }

    // `??` (not `||`) so falsy-but-valid results such as 0 or false survive.
    finalResult =
      typeof result === 'string' ? safeJsonParse(result) : (result ?? null)
  } catch (error: unknown) {
    const errorDuration = Date.now() - startTime

    // Flush remaining events
    yield* drainPendingEvents(pendingEvents)

    if (error instanceof MiddlewareAbortError) {
      throw error
    }

    const message = error instanceof Error ? error.message : 'Unknown error'
    results.push({
      toolCallId: toolCall.id,
      toolName,
      result: { error: message },
      state: 'output-error',
      duration: errorDuration,
    })

    if (middlewareHooks?.onAfterToolCall) {
      await middlewareHooks.onAfterToolCall({
        toolCall,
        tool,
        toolName,
        toolCallId: toolCall.id,
        ok: false,
        duration: errorDuration,
        error,
      })
    }
    return
  }

  // Success path runs outside the try so a failing onAfterToolCall hook can
  // no longer trigger a second (error) result for this tool call.
  results.push({
    toolCallId: toolCall.id,
    toolName,
    result: finalResult,
    duration,
  })

  if (middlewareHooks?.onAfterToolCall) {
    await middlewareHooks.onAfterToolCall({
      toolCall,
      tool,
      toolName,
      toolCallId: toolCall.id,
      ok: true,
      duration,
      result: finalResult,
    })
  }
}

/**
 * Key under which the approval decision for a tool call is stored in the
 * `approvals` map passed to executeToolCalls.
 */
function approvalIdFor(toolCallId: string): string {
  return `approval_${toolCallId}`
}

/**
 * Build the result recorded when the user declines a tool call.
 */
function declinedToolResult(toolCall: ToolCall, toolName: string): ToolResult {
  return {
    toolCallId: toolCall.id,
    toolName,
    result: { error: 'User declined tool execution' },
    state: 'output-error',
  }
}

/**
 * Handle a client-side tool call. If the client already supplied a result, it
 * is recorded; otherwise client execution is requested.
 */
function resolveClientExecution(
  toolCall: ToolCall,
  toolName: string,
  input: unknown,
  clientResults: Map<string, any>,
  results: Array<ToolResult>,
  needsClientExecution: Array<ClientToolRequest>,
): void {
  if (clientResults.has(toolCall.id)) {
    results.push({
      toolCallId: toolCall.id,
      toolName,
      result: clientResults.get(toolCall.id),
    })
  } else {
    needsClientExecution.push({
      toolCallId: toolCall.id,
      toolName,
      input,
    })
  }
}

/**
 * Run a server-side tool. Steps, in order:
 * 1. Build the tool's execution context; custom events the tool emits are
 *    tagged with its toolCallId and queued.
 * 2. Apply the middleware before-hook, which may skip the call, abort the run,
 *    or replace the input.
 * 3. Execute the tool.
 */
async function* runServerTool(
  toolCall: ToolCall,
  tool: Tool,
  toolName: string,
  input: unknown,
  results: Array<ToolResult>,
  createCustomEventChunk:
    | ((eventName: string, value: Record<string, any>) => CustomEvent)
    | undefined,
  middlewareHooks: ToolExecutionMiddlewareHooks | undefined,
): AsyncGenerator<CustomEvent, void, void> {
  const pendingEvents: Array<CustomEvent> = []
  const context: ToolExecutionContext = {
    toolCallId: toolCall.id,
    emitCustomEvent: (eventName: string, value: Record<string, any>) => {
      if (createCustomEventChunk) {
        pendingEvents.push(
          createCustomEventChunk(eventName, {
            ...value,
            toolCallId: toolCall.id,
          }),
        )
      }
    },
  }

  let effectiveInput = input
  if (middlewareHooks) {
    const decision = await applyBeforeToolCallDecision(
      toolCall,
      tool,
      effectiveInput,
      toolName,
      middlewareHooks,
      results,
    )
    if (!decision.proceed) return
    effectiveInput = decision.input
  }

  yield* executeServerTool(
    toolCall,
    tool,
    toolName,
    effectiveInput,
    context,
    pendingEvents,
    results,
    middlewareHooks,
  )
}

/**
 * Execute tool calls based on their configuration.
 * Yields CustomEvent chunks during tool execution for real-time progress updates.
 *
 * Handles three cases:
 * 1. Client tools (no execute) - request client to execute
 * 2. Server tools with approval - check approval before executing
 * 3. Normal server tools - execute immediately
 *
 * While any approval-gated tool in the batch still lacks a decision, only those
 * pending calls are processed (to produce approval requests). Every other call
 * is deferred so no side effects happen before the user decides. Unknown tools
 * are reported as errors regardless.
 *
 * @param toolCalls - Tool calls from the LLM
 * @param tools - Available tools with their configurations
 * @param approvals - Map of approval decisions (approval.id -> approved boolean)
 * @param clientResults - Map of client-side execution results (toolCallId -> result)
 * @param createCustomEventChunk - Factory to create CustomEvent chunks (optional)
 * @param middlewareHooks - Optional before/after hooks applied to server-executed tools
 * @returns Results ready for the LLM plus calls awaiting approval or client execution
 * @throws Error if a tool call's arguments are not valid JSON
 * @throws MiddlewareAbortError if a middleware hook aborts the run
 */
export async function* executeToolCalls(
  toolCalls: Array<ToolCall>,
  tools: ReadonlyArray<Tool>,
  approvals: Map<string, boolean> = new Map(),
  clientResults: Map<string, any> = new Map(),
  createCustomEventChunk?: (
    eventName: string,
    value: Record<string, any>,
  ) => CustomEvent,
  middlewareHooks?: ToolExecutionMiddlewareHooks,
): AsyncGenerator<CustomEvent, ExecuteToolCallsResult, void> {
  const results: Array<ToolResult> = []
  const needsApproval: Array<ApprovalRequest> = []
  const needsClientExecution: Array<ClientToolRequest> = []

  // Create tool lookup map
  const toolMap = new Map<string, Tool>()
  for (const tool of tools) {
    toolMap.set(tool.name, tool)
  }

  // Batch gating: when any tool in the batch still needs an approval decision,
  // defer all execution so side effects don't happen before the user decides.
  const hasPendingApprovals = toolCalls.some((tc) => {
    const t = toolMap.get(tc.function.name)
    return Boolean(t?.needsApproval) && !approvals.has(approvalIdFor(tc.id))
  })

  for (const toolCall of toolCalls) {
    const toolName = toolCall.function.name
    const tool = toolMap.get(toolName)

    if (!tool) {
      // Unknown tool - return error
      results.push({
        toolCallId: toolCall.id,
        toolName,
        result: { error: `Unknown tool: ${toolName}` },
        state: 'output-error',
      })
      continue
    }

    const approvalId = approvalIdFor(toolCall.id)

    // Skip non-pending tools while approvals are outstanding
    if (
      hasPendingApprovals &&
      (!tool.needsApproval || approvals.has(approvalId))
    ) {
      continue
    }

    // Parse arguments; invalid JSON throws to fail fast
    let input: unknown = parseToolArguments(toolCall.function.arguments)

    // Validate input against inputSchema (for Standard Schema compliant schemas)
    if (tool.inputSchema && isStandardSchema(tool.inputSchema)) {
      try {
        input = parseWithStandardSchema(tool.inputSchema, input)
      } catch (validationError: unknown) {
        const message =
          validationError instanceof Error
            ? validationError.message
            : 'Validation failed'
        results.push({
          toolCallId: toolCall.id,
          toolName,
          result: {
            error: `Input validation failed for tool ${tool.name}: ${message}`,
          },
          state: 'output-error',
        })
        continue
      }
    }

    // CASE 1: Client-side tool (no execute function)
    if (!tool.execute) {
      if (!tool.needsApproval) {
        // No approval needed - use the client result or request execution
        resolveClientExecution(
          toolCall,
          toolName,
          input,
          clientResults,
          results,
          needsClientExecution,
        )
      } else if (!approvals.has(approvalId)) {
        // Need approval first
        needsApproval.push({
          toolCallId: toolCall.id,
          toolName,
          input,
          approvalId,
        })
      } else if (approvals.get(approvalId)) {
        // Approved - use the client result or request execution
        resolveClientExecution(
          toolCall,
          toolName,
          input,
          clientResults,
          results,
          needsClientExecution,
        )
      } else {
        results.push(declinedToolResult(toolCall, toolName))
      }
      continue
    }

    // CASE 2: Server tool with approval required
    if (tool.needsApproval) {
      if (!approvals.has(approvalId)) {
        // Need approval
        needsApproval.push({
          toolCallId: toolCall.id,
          toolName,
          input,
          approvalId,
        })
      } else if (approvals.get(approvalId)) {
        yield* runServerTool(
          toolCall,
          tool,
          toolName,
          input,
          results,
          createCustomEventChunk,
          middlewareHooks,
        )
      } else {
        results.push(declinedToolResult(toolCall, toolName))
      }
      continue
    }

    // CASE 3: Normal server tool - execute immediately
    yield* runServerTool(
      toolCall,
      tool,
      toolName,
      input,
      results,
      createCustomEventChunk,
      middlewareHooks,
    )
  }

  return { results, needsApproval, needsClientExecution }
}