UNPKG

@tanstack/ai

Version:

Core TanStack AI library - Open source AI SDK

460 lines (459 loc) 13.6 kB
import { isStandardSchema, parseWithStandardSchema } from "./schema-converter.js";

/**
 * Parse a string as JSON, returning the input unchanged when it is not valid JSON.
 *
 * @param {string} value - Text that may or may not contain JSON.
 * @returns {unknown} The parsed value, or `value` itself when parsing fails.
 */
function safeJsonParse(value) {
  try {
    return JSON.parse(value);
  } catch {
    return value;
  }
}

/**
 * Normalize a raw tool result before it is stored in a result entry.
 *
 * Strings are decoded with `safeJsonParse`; `null`/`undefined` become `null`.
 * `??` (not `||`) is used so legitimate falsy results such as `0` or `false`
 * are kept instead of being replaced by `null`.
 *
 * @param {unknown} value - Value returned by a tool or supplied by middleware.
 * @returns {unknown} The normalized result.
 */
function normalizeToolResult(value) {
  return typeof value === "string" ? safeJsonParse(value) : value ?? null;
}

/**
 * Error thrown when an `onBeforeToolCall` middleware hook returns an `abort`
 * decision. `executeServerTool` rethrows it instead of recording it as a tool
 * error result, so it propagates out of `executeToolCalls`.
 */
class MiddlewareAbortError extends Error {
  /**
   * @param {string} reason - Abort reason, used as the error message.
   */
  constructor(reason) {
    super(reason);
    this.name = "MiddlewareAbortError";
  }
}

/**
 * Accumulates streamed AG-UI tool-call events (TOOL_CALL_START, TOOL_CALL_ARGS,
 * TOOL_CALL_END) into complete tool calls and can execute them against the
 * configured tools.
 */
class ToolCallManager {
  /**
   * @param {Array<object>} tools - Tool definitions, matched by `name`; each may
   *   provide `execute`, `inputSchema` and `outputSchema`.
   */
  constructor(tools) {
    // Tool calls keyed by their stream index.
    this.toolCallsMap = /* @__PURE__ */ new Map();
    this.tools = tools;
  }

  /**
   * Add a TOOL_CALL_START event to begin tracking a tool call (AG-UI).
   * When the event has no `index`, the call is keyed after the currently
   * tracked calls.
   */
  addToolCallStartEvent(event) {
    const index = event.index ?? this.toolCallsMap.size;
    this.toolCallsMap.set(index, {
      id: event.toolCallId,
      type: "function",
      function: { name: event.toolName, arguments: "" },
      ...(event.providerMetadata && { providerMetadata: event.providerMetadata }),
    });
  }

  /**
   * Add a TOOL_CALL_ARGS event to accumulate arguments (AG-UI).
   * The delta is appended to the first tracked call with a matching id;
   * deltas for unknown ids are ignored.
   */
  addToolCallArgsEvent(event) {
    for (const toolCall of this.toolCallsMap.values()) {
      if (toolCall.id === event.toolCallId) {
        toolCall.function.arguments += event.delta;
        break;
      }
    }
  }

  /**
   * Complete a tool call with its final input.
   * Called when TOOL_CALL_END is received. When the event carries `input`, its
   * JSON serialization replaces the streamed argument text.
   */
  completeToolCall(event) {
    for (const toolCall of this.toolCallsMap.values()) {
      if (toolCall.id === event.toolCallId) {
        if (event.input !== void 0) {
          toolCall.function.arguments = JSON.stringify(event.input);
        }
        break;
      }
    }
  }

  /**
   * Check if there are any complete tool calls to execute.
   */
  hasToolCalls() {
    return this.getToolCalls().length > 0;
  }

  /**
   * Get all complete tool calls (those with an id and a non-blank name).
   */
  getToolCalls() {
    return Array.from(this.toolCallsMap.values()).filter(
      (tc) => tc.id && tc.function.name && tc.function.name.trim().length > 0
    );
  }

  /**
   * Execute all complete tool calls sequentially and return tool result messages.
   *
   * Yields one TOOL_CALL_END event per call. Failures (unparseable arguments,
   * schema validation errors, errors thrown by `execute`) are not thrown; they
   * are reported as an `Error executing tool: ...` result string.
   *
   * @param finishEvent - RUN_FINISHED event from the stream; its `model` is
   *   copied onto every yielded event.
   * @returns Array of `{ role: "tool", content, toolCallId }` messages.
   */
  async *executeTools(finishEvent) {
    const toolResults = [];
    for (const toolCall of this.getToolCalls()) {
      const tool = this.tools.find((t) => t.name === toolCall.function.name);
      let toolResultContent;
      if (tool?.execute) {
        try {
          let args;
          try {
            const argsString = toolCall.function.arguments.trim() || "{}";
            // A literal `null` is treated the same as empty arguments (`{}`).
            args = JSON.parse(argsString === "null" ? "{}" : argsString);
          } catch (parseError) {
            throw new Error(
              `Failed to parse tool arguments as JSON: ${toolCall.function.arguments}`,
              { cause: parseError }
            );
          }
          if (tool.inputSchema && isStandardSchema(tool.inputSchema)) {
            try {
              args = parseWithStandardSchema(tool.inputSchema, args);
            } catch (validationError) {
              const message =
                validationError instanceof Error ? validationError.message : "Validation failed";
              throw new Error(`Input validation failed for tool ${tool.name}: ${message}`, {
                cause: validationError,
              });
            }
          }
          let result = await tool.execute(args);
          // Output validation is skipped for null/undefined results.
          if (
            tool.outputSchema &&
            isStandardSchema(tool.outputSchema) &&
            result !== void 0 &&
            result !== null
          ) {
            try {
              result = parseWithStandardSchema(tool.outputSchema, result);
            } catch (validationError) {
              const message =
                validationError instanceof Error ? validationError.message : "Validation failed";
              throw new Error(`Output validation failed for tool ${tool.name}: ${message}`, {
                cause: validationError,
              });
            }
          }
          toolResultContent = typeof result === "string" ? result : JSON.stringify(result);
        } catch (error) {
          const message = error instanceof Error ? error.message : "Unknown error";
          toolResultContent = `Error executing tool: ${message}`;
        }
      } else {
        toolResultContent = `Tool ${toolCall.function.name} does not have an execute function`;
      }
      yield {
        type: "TOOL_CALL_END",
        toolCallId: toolCall.id,
        toolName: toolCall.function.name,
        model: finishEvent.model,
        timestamp: Date.now(),
        result: toolResultContent,
      };
      toolResults.push({
        role: "tool",
        content: toolResultContent,
        toolCallId: toolCall.id,
      });
    }
    return toolResults;
  }

  /**
   * Clear the tool calls map for the next iteration.
   */
  clear() {
    this.toolCallsMap.clear();
  }
}

/**
 * Await `executionPromise` while periodically draining `pendingEvents`, yielding
 * each queued event as soon as it is seen.
 *
 * The queue is polled every `pollIntervalMs`; each poll timer is cleared once
 * the race settles, so no timer outlives the execution. If the promise rejects,
 * the rejection propagates out of this generator (events still queued at that
 * point are left for the caller to drain).
 *
 * @param {Promise<unknown>} executionPromise - The running tool execution.
 * @param {Array<unknown>} pendingEvents - Mutable queue filled while the tool runs.
 * @param {number} [pollIntervalMs=10] - Delay between queue polls.
 * @returns The value `executionPromise` resolved with.
 */
async function* executeWithEventPolling(executionPromise, pendingEvents, pollIntervalMs = 10) {
  const state = { done: false, result: void 0 };
  const executionWithFlag = executionPromise.then((r) => {
    state.done = true;
    state.result = r;
    return r;
  });
  while (!state.done) {
    let pollTimer;
    try {
      await Promise.race([
        executionWithFlag,
        new Promise((resolve) => {
          pollTimer = setTimeout(resolve, pollIntervalMs);
        }),
      ]);
    } finally {
      clearTimeout(pollTimer);
    }
    while (pendingEvents.length > 0) {
      yield pendingEvents.shift();
    }
  }
  while (pendingEvents.length > 0) {
    yield pendingEvents.shift();
  }
  return state.result;
}

/**
 * Run the `onBeforeToolCall` middleware hook (if any) and interpret its decision.
 *
 * - No hook or no decision: proceed with `input` unchanged.
 * - `abort`: throws `MiddlewareAbortError`.
 * - `skip`: records `decision.result` in `results` (with duration 0), calls
 *   `onAfterToolCall` with the raw skip result, and returns `{ proceed: false }`.
 * - Any other decision: proceed with `decision.args` as the new input.
 *   NOTE(review): presumably this is a "transform args" decision — a decision
 *   without `args` would replace the input with `undefined`; confirm.
 *
 * @returns {Promise<{ proceed: boolean, input?: unknown }>}
 * @throws {MiddlewareAbortError} When the middleware aborts.
 */
async function applyBeforeToolCallDecision(toolCall, tool, input, toolName, middlewareHooks, results) {
  if (!middlewareHooks.onBeforeToolCall) {
    return { proceed: true, input };
  }
  const decision = await middlewareHooks.onBeforeToolCall(toolCall, tool, input);
  if (!decision) {
    return { proceed: true, input };
  }
  if (decision.type === "abort") {
    throw new MiddlewareAbortError(decision.reason || "Aborted by middleware");
  }
  if (decision.type === "skip") {
    const skipResult = decision.result;
    results.push({
      toolCallId: toolCall.id,
      toolName,
      result: normalizeToolResult(skipResult),
      duration: 0,
    });
    if (middlewareHooks.onAfterToolCall) {
      await middlewareHooks.onAfterToolCall({
        toolCall,
        tool,
        toolName,
        toolCallId: toolCall.id,
        ok: true,
        duration: 0,
        result: skipResult,
      });
    }
    return { proceed: false };
  }
  return { proceed: true, input: decision.args };
}

/**
 * Execute a server-side tool, streaming any events it emits while it runs.
 *
 * On success the (optionally output-validated) result is normalized with
 * `normalizeToolResult` and pushed to `results` with its duration in
 * milliseconds. On failure an `output-error` entry is pushed instead. In both
 * cases `onAfterToolCall` is invoked when present. A `MiddlewareAbortError`
 * is rethrown rather than recorded. Queued events are always drained before
 * results are recorded.
 */
async function* executeServerTool(toolCall, tool, toolName, input, context, pendingEvents, results, middlewareHooks) {
  const startTime = Date.now();
  try {
    const executionPromise = Promise.resolve(tool.execute(input, context));
    let result = yield* executeWithEventPolling(executionPromise, pendingEvents);
    const duration = Date.now() - startTime;
    while (pendingEvents.length > 0) {
      yield pendingEvents.shift();
    }
    if (
      tool.outputSchema &&
      isStandardSchema(tool.outputSchema) &&
      result !== void 0 &&
      result !== null
    ) {
      result = parseWithStandardSchema(tool.outputSchema, result);
    }
    const finalResult = normalizeToolResult(result);
    results.push({ toolCallId: toolCall.id, toolName, result: finalResult, duration });
    if (middlewareHooks?.onAfterToolCall) {
      await middlewareHooks.onAfterToolCall({
        toolCall,
        tool,
        toolName,
        toolCallId: toolCall.id,
        ok: true,
        duration,
        result: finalResult,
      });
    }
  } catch (error) {
    const duration = Date.now() - startTime;
    while (pendingEvents.length > 0) {
      yield pendingEvents.shift();
    }
    if (error instanceof MiddlewareAbortError) {
      throw error;
    }
    const message = error instanceof Error ? error.message : "Unknown error";
    results.push({
      toolCallId: toolCall.id,
      toolName,
      result: { error: message },
      state: "output-error",
      duration,
    });
    if (middlewareHooks?.onAfterToolCall) {
      await middlewareHooks.onAfterToolCall({
        toolCall,
        tool,
        toolName,
        toolCallId: toolCall.id,
        ok: false,
        duration,
        error,
      });
    }
  }
}

/**
 * Resolve a batch of tool calls: run server tools, collect client results,
 * and report which calls still need approval or client-side execution.
 *
 * While any call in the batch is waiting for approval (its `approval_<id>` key
 * is absent from `approvals`), only those waiting calls are processed; every
 * other known tool call is deferred.
 *
 * @param {Array<object>} toolCalls - Calls with `id` and `function.{name, arguments}`.
 * @param {Iterable<object>} tools - Tool definitions, matched by `name`.
 * @param {Map<string, boolean>} [approvals] - Approval decisions keyed `approval_<toolCallId>`.
 * @param {Map<string, unknown>} [clientResults] - Client-side results keyed by tool call id.
 * @param {Function} [createCustomEventChunk] - Builds the chunk yielded for
 *   `context.emitCustomEvent(eventName, value)`.
 * @param {object} [middlewareHooks] - Optional `onBeforeToolCall` / `onAfterToolCall` hooks.
 * @returns {{ results: Array, needsApproval: Array, needsClientExecution: Array }}
 * @throws {Error} When a call's arguments are not valid JSON.
 * @throws {MiddlewareAbortError} When middleware aborts.
 */
async function* executeToolCalls(toolCalls, tools, approvals = /* @__PURE__ */ new Map(), clientResults = /* @__PURE__ */ new Map(), createCustomEventChunk, middlewareHooks) {
  const results = [];
  const needsApproval = [];
  const needsClientExecution = [];
  const toolMap = /* @__PURE__ */ new Map();
  for (const tool of tools) {
    toolMap.set(tool.name, tool);
  }
  const hasPendingApprovals = toolCalls.some((tc) => {
    const t = toolMap.get(tc.function.name);
    return t?.needsApproval && !approvals.has(`approval_${tc.id}`);
  });

  for (const toolCall of toolCalls) {
    const toolName = toolCall.function.name;
    const tool = toolMap.get(toolName);
    if (!tool) {
      results.push({
        toolCallId: toolCall.id,
        toolName,
        result: { error: `Unknown tool: ${toolName}` },
        state: "output-error",
      });
      continue;
    }
    const approvalId = `approval_${toolCall.id}`;
    // Defer everything except calls that are still awaiting approval.
    if (hasPendingApprovals && (!tool.needsApproval || approvals.has(approvalId))) {
      continue;
    }

    let input;
    const argsStr = toolCall.function.arguments.trim() || "{}";
    try {
      // A literal `null` is treated as empty arguments, matching ToolCallManager.
      input = JSON.parse(argsStr === "null" ? "{}" : argsStr);
    } catch (parseError) {
      throw new Error(`Failed to parse tool arguments as JSON: ${argsStr}`, { cause: parseError });
    }
    if (tool.inputSchema && isStandardSchema(tool.inputSchema)) {
      try {
        input = parseWithStandardSchema(tool.inputSchema, input);
      } catch (validationError) {
        const message =
          validationError instanceof Error ? validationError.message : "Validation failed";
        results.push({
          toolCallId: toolCall.id,
          toolName,
          result: { error: `Input validation failed for tool ${tool.name}: ${message}` },
          state: "output-error",
        });
        continue;
      }
    }

    if (tool.needsApproval) {
      if (!approvals.has(approvalId)) {
        needsApproval.push({ toolCallId: toolCall.id, toolName, input, approvalId });
        continue;
      }
      if (!approvals.get(approvalId)) {
        results.push({
          toolCallId: toolCall.id,
          toolName,
          result: { error: "User declined tool execution" },
          state: "output-error",
        });
        continue;
      }
    }

    // From here on the call is either approved or needs no approval.
    if (!tool.execute) {
      // Client-side tool: use a supplied result, otherwise request client execution.
      if (clientResults.has(toolCall.id)) {
        results.push({ toolCallId: toolCall.id, toolName, result: clientResults.get(toolCall.id) });
      } else {
        needsClientExecution.push({ toolCallId: toolCall.id, toolName, input });
      }
      continue;
    }

    if (middlewareHooks) {
      const decision = await applyBeforeToolCallDecision(
        toolCall,
        tool,
        input,
        toolName,
        middlewareHooks,
        results
      );
      if (!decision.proceed) continue;
      input = decision.input;
    }

    const pendingEvents = [];
    const context = {
      toolCallId: toolCall.id,
      emitCustomEvent: (eventName, value) => {
        if (createCustomEventChunk) {
          pendingEvents.push(
            createCustomEventChunk(eventName, { ...value, toolCallId: toolCall.id })
          );
        }
      },
    };
    yield* executeServerTool(
      toolCall,
      tool,
      toolName,
      input,
      context,
      pendingEvents,
      results,
      middlewareHooks
    );
  }
  return { results, needsApproval, needsClientExecution };
}

export { MiddlewareAbortError, ToolCallManager, executeToolCalls };
//# sourceMappingURL=tool-calls.js.map