@tanstack/ai
Version:
Type-safe TypeScript AI SDK for streaming chat, tool calling, agents, structured outputs, and multimodal generation.
513 lines (512 loc) • 15.4 kB
JavaScript
import { normalizeToolResult } from "../../../utilities/tool-result.js";
import { isStandardSchema, parseWithStandardSchema } from "./schema-converter.js";
/**
 * Best-effort JSON decoding.
 *
 * @param {string} value - Text that may or may not be valid JSON.
 * @returns {*} The decoded value when `value` parses as JSON; otherwise
 *   `value` itself, unchanged.
 */
function safeJsonParse(value) {
  let decoded = value;
  try {
    decoded = JSON.parse(value);
  } catch {
    // Not JSON: deliberately keep the raw input as the result.
  }
  return decoded;
}
/**
 * Error type used to signal that middleware asked for the run to be aborted.
 * Code in this file re-throws it instead of converting it into a tool error
 * result.
 */
class MiddlewareAbortError extends Error {
  // Own, writable `name` so the error is identifiable in logs and checks.
  name = "MiddlewareAbortError";

  /**
   * @param {string} [reason] - Human-readable abort reason, used as the
   *   error message.
   */
  constructor(reason) {
    super(reason);
  }
}
/**
 * Collects streamed tool-call events (AG-UI) into complete tool calls and
 * runs the matching tools' `execute` functions.
 */
class ToolCallManager {
  /** Tool calls seen so far, keyed by stream index. */
  toolCallsMap = /* @__PURE__ */ new Map();
  /** Tool definitions; looked up by `name` when executing. */
  tools;

  /**
   * @param {Array<object>} tools - Tool definitions (`name`, optional
   *   `execute`, `inputSchema`, `outputSchema`).
   */
  constructor(tools) {
    this.tools = tools;
  }

  /**
   * Add a TOOL_CALL_START event to begin tracking a tool call (AG-UI).
   *
   * The call is stored under `event.index`, or under the current map size
   * when no index is given. The name is taken from `toolCallName`, falling
   * back to `toolName`.
   */
  addToolCallStartEvent(event) {
    const slot = event.index ?? this.toolCallsMap.size;
    const name = event.toolCallName ?? event.toolName;
    const entry = {
      id: event.toolCallId,
      type: "function",
      function: { name, arguments: "" },
    };
    // Only carry a metadata key when the event actually supplied one.
    if (event.metadata !== undefined) {
      entry.metadata = event.metadata;
    }
    this.toolCallsMap.set(slot, entry);
  }

  /**
   * Add a TOOL_CALL_ARGS event to accumulate arguments (AG-UI).
   *
   * Appends `event.delta` to the first tracked call whose id matches;
   * events for unknown ids are ignored.
   */
  addToolCallArgsEvent(event) {
    for (const tracked of this.toolCallsMap.values()) {
      if (tracked.id !== event.toolCallId) {
        continue;
      }
      tracked.function.arguments += event.delta;
      return;
    }
  }

  /**
   * Complete a tool call with its final input.
   * Called when TOOL_CALL_END is received.
   *
   * When `event.input` is provided it replaces the accumulated argument
   * text; non-object inputs are stored as `{}`.
   */
  completeToolCall(event) {
    for (const tracked of this.toolCallsMap.values()) {
      if (tracked.id !== event.toolCallId) {
        continue;
      }
      if (event.input !== undefined) {
        const isObjectInput = event.input && typeof event.input === "object";
        tracked.function.arguments = JSON.stringify(isObjectInput ? event.input : {});
      }
      return;
    }
  }

  /**
   * Check if there are any complete tool calls to execute.
   */
  hasToolCalls() {
    const complete = this.getToolCalls();
    return complete.length > 0;
  }

  /**
   * Get all complete tool calls (filtered for valid ID and name).
   */
  getToolCalls() {
    const complete = [];
    for (const tracked of this.toolCallsMap.values()) {
      if (!tracked.id) {
        continue;
      }
      const { name } = tracked.function;
      if (!name || name.trim().length === 0) {
        continue;
      }
      complete.push(tracked);
    }
    return complete;
  }

  /**
   * Execute all tool calls and return tool result messages.
   * Yields TOOL_CALL_END events for streaming.
   *
   * Failures (unparseable arguments, schema validation, errors thrown by the
   * tool) are reported as the result text with `state: "output-error"`
   * rather than thrown. A call whose tool is missing or has no `execute`
   * gets a placeholder result text and no error state.
   *
   * @param finishEvent - RUN_FINISHED event from the stream; its `model` is
   *   copied onto each yielded event.
   * @param contextArgs - Optional; when at least one extra argument is
   *   passed, the first is handed to tools as `context` inside a second
   *   `execute` argument.
   * @returns Array of `{ role: "tool", content, toolCallId }` messages.
   */
  async *executeTools(finishEvent, ...contextArgs) {
    const pendingCalls = this.getToolCalls();
    const [userContext] = contextArgs;
    const hasRuntimeContext = contextArgs.length > 0;
    const toolResults = [];

    for (const toolCall of pendingCalls) {
      const toolName = toolCall.function.name;
      const tool = this.tools.find((candidate) => candidate.name === toolName);
      let content;
      let errorState;

      if (!tool?.execute) {
        content = `Tool ${toolName} does not have an execute function`;
      } else {
        try {
          let args;
          try {
            const rawArgs = toolCall.function.arguments.trim() || "{}";
            const decoded = JSON.parse(rawArgs);
            // Tools always receive an object; scalars and null become {}.
            args = decoded && typeof decoded === "object" ? decoded : {};
          } catch {
            throw new Error(
              `Failed to parse tool arguments as JSON: ${toolCall.function.arguments}`
            );
          }

          if (tool.inputSchema && isStandardSchema(tool.inputSchema)) {
            try {
              args = parseWithStandardSchema(tool.inputSchema, args);
            } catch (validationError) {
              const reason =
                validationError instanceof Error
                  ? validationError.message
                  : "Validation failed";
              throw new Error(`Input validation failed for tool ${tool.name}: ${reason}`);
            }
          }

          const executionContext = {
            toolCallId: toolCall.id,
            context: userContext,
            // Custom events are not surfaced on this execution path.
            emitCustomEvent: () => {},
          };
          let output;
          if (hasRuntimeContext) {
            output = await tool.execute(args, executionContext);
          } else {
            output = await tool.execute(args);
          }

          if (tool.outputSchema && isStandardSchema(tool.outputSchema)) {
            try {
              output = parseWithStandardSchema(tool.outputSchema, output);
            } catch (validationError) {
              const reason =
                validationError instanceof Error
                  ? validationError.message
                  : "Validation failed";
              throw new Error(`Output validation failed for tool ${tool.name}: ${reason}`);
            }
          }

          content = normalizeToolResult(output);
        } catch (error) {
          const reason = error instanceof Error ? error.message : "Unknown error";
          content = `Error executing tool: ${reason}`;
          errorState = "output-error";
        }
      }

      const endEvent = {
        type: "TOOL_CALL_END",
        toolCallId: toolCall.id,
        toolCallName: toolName,
        toolName,
        model: finishEvent.model,
        timestamp: Date.now(),
        result: content,
      };
      if (errorState !== undefined) {
        endEvent.state = errorState;
      }
      yield endEvent;

      toolResults.push({
        role: "tool",
        content,
        toolCallId: toolCall.id,
      });
    }

    return toolResults;
  }

  /**
   * Clear the tool calls map for the next iteration.
   */
  clear() {
    this.toolCallsMap.clear();
  }
}
/**
 * Await a tool execution while streaming out events queued in the meantime.
 *
 * Roughly every 10 ms, and once more when the execution settles, every entry
 * currently in `pendingEvents` is shifted off and yielded in order.
 *
 * @param {Promise<*>} executionPromise - The in-flight tool execution.
 * @param {Array<*>} pendingEvents - Queue that other code pushes events
 *   onto; it is drained (mutated) by this generator.
 * @returns {AsyncGenerator<*, *>} Yields queued events; the generator's
 *   return value is the value `executionPromise` fulfilled with.
 * @throws Whatever `executionPromise` rejects with. Events still queued at
 *   that point are left in `pendingEvents` for the caller.
 */
async function* executeWithEventPolling(executionPromise, pendingEvents) {
  const state = { done: false, result: undefined };
  const executionWithFlag = executionPromise.then((value) => {
    state.done = true;
    state.result = value;
    return value;
  });

  while (!state.done) {
    let pollTimer;
    const pollTick = new Promise((resolve) => {
      pollTimer = setTimeout(resolve, 10);
    });
    try {
      await Promise.race([executionWithFlag, pollTick]);
    } finally {
      // Cancel the tick when the execution settled first (or rejected) so no
      // stray timer outlives this generator. No-op if it already fired.
      clearTimeout(pollTimer);
    }
    let queued;
    while ((queued = pendingEvents.shift()) !== undefined) {
      yield queued;
    }
  }

  // The consumer may have queued more while we were suspended at a yield.
  let trailing;
  while ((trailing = pendingEvents.shift()) !== undefined) {
    yield trailing;
  }
  return state.result;
}
/**
 * Consult the `onBeforeToolCall` middleware hook and translate its decision.
 *
 * - No hook, or a falsy decision: proceed with the given `input`.
 * - `type: "abort"`: throws `MiddlewareAbortError`.
 * - `type: "skip"`: records `decision.result` in `results` (string results
 *   are JSON-decoded when possible, missing results become `null`), notifies
 *   `onAfterToolCall` if present with the raw result, and does not proceed.
 * - Any other decision: proceed with `decision.args` as the new input.
 *
 * @param {object} toolCall - The tool call being considered (`id` is read).
 * @param {object} tool - The resolved tool definition, passed to the hooks.
 * @param {*} input - Parsed tool arguments.
 * @param {string} toolName - Name recorded on results and hook payloads.
 * @param {object} middlewareHooks - Hook container.
 * @param {Array<object>} results - Result list; appended to on "skip".
 * @returns {Promise<{proceed: true, input: *} | {proceed: false}>}
 * @throws {MiddlewareAbortError} When the hook decides to abort.
 */
async function applyBeforeToolCallDecision(toolCall, tool, input, toolName, middlewareHooks, results) {
  // Called as a method so the hook keeps `middlewareHooks` as its `this`.
  const decision = middlewareHooks.onBeforeToolCall
    ? await middlewareHooks.onBeforeToolCall(toolCall, tool, input)
    : undefined;
  if (!decision) {
    return { proceed: true, input };
  }

  switch (decision.type) {
    case "abort":
      throw new MiddlewareAbortError(decision.reason || "Aborted by middleware");

    case "skip": {
      const skipResult = decision.result;
      const recorded =
        typeof skipResult === "string" ? safeJsonParse(skipResult) : skipResult ?? null;
      results.push({
        toolCallId: toolCall.id,
        toolName,
        result: recorded,
        duration: 0,
      });
      if (middlewareHooks.onAfterToolCall) {
        await middlewareHooks.onAfterToolCall({
          toolCall,
          tool,
          toolName,
          toolCallId: toolCall.id,
          ok: true,
          duration: 0,
          result: skipResult,
        });
      }
      return { proceed: false };
    }

    default:
      return { proceed: true, input: decision.args };
  }
}
/**
 * Run one server-side tool, stream the events it queues, and record the
 * outcome.
 *
 * On success a `{ toolCallId, toolName, result, duration }` entry is appended
 * to `results` (string results are JSON-decoded when possible, missing
 * results become `null`). On failure the entry carries
 * `result: { error: message }` and `state: "output-error"`. In both cases
 * `middlewareHooks.onAfterToolCall`, when present, is awaited with the
 * outcome.
 *
 * NOTE(review): the success-path `onAfterToolCall` runs inside the `try`, so
 * if it throws after the success entry was recorded, the `catch` records a
 * second, error entry for the same call and invokes the hook again. Confirm
 * whether that is intended.
 *
 * @param {object} toolCall - Tool call being executed (`id` is read).
 * @param {object} tool - Tool definition (`execute`, optional `outputSchema`).
 * @param {string} toolName - Name recorded on results and hook payloads.
 * @param {*} input - Validated tool arguments passed to `execute`.
 * @param {object} context - Second argument passed to `execute`.
 * @param {Array<*>} pendingEvents - Event queue drained while executing.
 * @param {Array<object>} results - Result list appended to.
 * @param {object} [middlewareHooks] - Optional hook container.
 * @returns {AsyncGenerator<*, void>} Yields the drained events.
 * @throws {MiddlewareAbortError} Re-thrown untouched when caught here.
 */
async function* executeServerTool(toolCall, tool, toolName, input, context, pendingEvents, results, middlewareHooks) {
  const startedAt = Date.now();

  // Shift-and-yield everything currently queued, stopping at the first gap.
  function* flushPendingEvents() {
    for (
      let queued = pendingEvents.shift();
      queued !== undefined;
      queued = pendingEvents.shift()
    ) {
      yield queued;
    }
  }

  // Report the outcome to the after-hook, if one is installed.
  const notifyAfterToolCall = async (outcome) => {
    if (middlewareHooks?.onAfterToolCall) {
      await middlewareHooks.onAfterToolCall({
        toolCall,
        tool,
        toolName,
        toolCallId: toolCall.id,
        ...outcome,
      });
    }
  };

  try {
    if (!tool.execute) {
      throw new Error(`Tool ${toolName} has no execute() implementation`);
    }
    let output = yield* executeWithEventPolling(
      Promise.resolve(tool.execute(input, context)),
      pendingEvents
    );
    const duration = Date.now() - startedAt;
    yield* flushPendingEvents();

    if (tool.outputSchema && isStandardSchema(tool.outputSchema)) {
      output = parseWithStandardSchema(tool.outputSchema, output);
    }
    const finalResult = typeof output === "string" ? safeJsonParse(output) : output ?? null;

    results.push({
      toolCallId: toolCall.id,
      toolName,
      result: finalResult,
      duration,
    });
    await notifyAfterToolCall({ ok: true, duration, result: finalResult });
  } catch (error) {
    const duration = Date.now() - startedAt;
    // Events queued before the failure are still delivered.
    yield* flushPendingEvents();

    if (error instanceof MiddlewareAbortError) {
      throw error;
    }
    const message = error instanceof Error ? error.message : "Unknown error";
    results.push({
      toolCallId: toolCall.id,
      toolName,
      result: { error: message },
      state: "output-error",
      duration,
    });
    await notifyAfterToolCall({ ok: false, duration, error });
  }
}
/**
 * Turn a result reported by a client-executed tool into a result entry.
 *
 * The raw result is validated against `tool.outputSchema` when that is a
 * standard schema. String results are JSON-decoded when possible, and
 * `null`/`undefined` results are stored as `null`. Any error raised along the
 * way yields an entry with `result: { error }` and `state: "output-error"`
 * instead of throwing.
 *
 * @param {string} toolCallId - Id of the tool call the result belongs to.
 * @param {string} toolName - Name of the tool.
 * @param {object} tool - Tool definition (optional `outputSchema`).
 * @param {*} rawResult - Result as received from the client.
 * @returns {{toolCallId: string, toolName: string, result: *, state?: string}}
 */
function buildClientToolResult(toolCallId, toolName, tool, rawResult) {
  let entry;
  try {
    const schema = tool.outputSchema;
    const validated =
      schema && isStandardSchema(schema)
        ? parseWithStandardSchema(schema, rawResult)
        : rawResult;
    const payload = typeof validated === "string" ? safeJsonParse(validated) : validated ?? null;
    entry = { toolCallId, toolName, result: payload };
  } catch (error) {
    const message = error instanceof Error ? error.message : "Validation failed";
    entry = {
      toolCallId,
      toolName,
      result: { error: message },
      state: "output-error",
    };
  }
  return entry;
}
/**
 * Route a batch of tool calls: run server tools, collect client-provided
 * results, and report which calls still need approval or client execution.
 *
 * Per call, in order:
 * 1. Unknown tool name: an `output-error` result is recorded.
 * 2. If any call in the batch is still waiting for approval, only those
 *    waiting calls are processed; all others are skipped for now.
 * 3. Arguments are parsed as JSON (empty text counts as `{}`; non-object
 *    values become `{}`) and validated against `tool.inputSchema` when it is
 *    a standard schema. A validation failure is recorded as an
 *    `output-error` result.
 * 4. Tools with `needsApproval`: no decision yet goes to `needsApproval`; a
 *    falsy decision records "User declined tool execution".
 * 5. Tools without `execute` (client tools): use the entry in
 *    `clientResults` if present, otherwise go to `needsClientExecution`.
 * 6. Server tools: the before-call middleware may skip or rewrite the
 *    input, then the tool is executed and its events are yielded.
 *
 * @param {Array<object>} toolCalls - Calls with `id` and
 *   `function: { name, arguments }`.
 * @param {Iterable<object>} tools - Tool definitions, matched by `name`.
 * @param {Map<string, boolean>} [approvals] - Approval decisions keyed by
 *   `approval_<toolCallId>`.
 * @param {Map<string, *>} [clientResults] - Client results keyed by tool
 *   call id.
 * @param {Function} [createCustomEventChunk] - Called as
 *   `(eventName, value)` to build the chunk queued by `emitCustomEvent`;
 *   without it custom events are dropped.
 * @param {object} [middlewareHooks] - Optional middleware hook container.
 * @param {*} [userContext] - Exposed to tools as `context.context`.
 * @param {*} [abortSignal] - Exposed to tools as `context.abortSignal`.
 * @returns {AsyncGenerator<*, {results: Array, needsApproval: Array, needsClientExecution: Array}>}
 * @throws {Error} When a call's arguments are not valid JSON; this aborts
 *   the whole batch rather than recording a per-call error.
 */
async function* executeToolCalls(toolCalls, tools, approvals = /* @__PURE__ */ new Map(), clientResults = /* @__PURE__ */ new Map(), createCustomEventChunk, middlewareHooks, userContext, abortSignal) {
  const results = [];
  const needsApproval = [];
  const needsClientExecution = [];
  const toolMap = new Map(Array.from(tools, (tool) => [tool.name, tool]));
  const approvalKeyFor = (id) => `approval_${id}`;

  const hasPendingApprovals = toolCalls.some((call) => {
    const candidate = toolMap.get(call.function.name);
    return candidate?.needsApproval && !approvals.has(approvalKeyFor(call.id));
  });

  for (const toolCall of toolCalls) {
    const toolName = toolCall.function.name;
    const tool = toolMap.get(toolName);
    const toolCallId = toolCall.id;

    if (!tool) {
      results.push({
        toolCallId,
        toolName,
        result: { error: `Unknown tool: ${toolName}` },
        state: "output-error",
      });
      continue;
    }

    // While approvals are outstanding, only the calls still awaiting a
    // decision are handled in this pass.
    if (hasPendingApprovals && (!tool.needsApproval || approvals.has(approvalKeyFor(toolCallId)))) {
      continue;
    }

    const argsStr = toolCall.function.arguments.trim() || "{}";
    let input;
    try {
      const decoded = JSON.parse(argsStr);
      input = decoded && typeof decoded === "object" ? decoded : {};
    } catch {
      // Malformed arguments fail the whole batch.
      throw new Error(`Failed to parse tool arguments as JSON: ${argsStr}`);
    }

    if (tool.inputSchema && isStandardSchema(tool.inputSchema)) {
      try {
        input = parseWithStandardSchema(tool.inputSchema, input);
      } catch (validationError) {
        const reason =
          validationError instanceof Error ? validationError.message : "Validation failed";
        results.push({
          toolCallId,
          toolName,
          result: {
            error: `Input validation failed for tool ${tool.name}: ${reason}`,
          },
          state: "output-error",
        });
        continue;
      }
    }

    // Approval gate, shared by client and server tools.
    if (tool.needsApproval) {
      const approvalId = approvalKeyFor(toolCallId);
      if (!approvals.has(approvalId)) {
        needsApproval.push({ toolCallId, toolName, input, approvalId });
        continue;
      }
      if (!approvals.get(approvalId)) {
        results.push({
          toolCallId,
          toolName,
          result: { error: "User declined tool execution" },
          state: "output-error",
        });
        continue;
      }
    }

    // Client tools: take the reported result, or ask the client to run it.
    if (!tool.execute) {
      if (clientResults.has(toolCallId)) {
        results.push(
          buildClientToolResult(toolCallId, toolName, tool, clientResults.get(toolCallId))
        );
      } else {
        needsClientExecution.push({ toolCallId, toolName, input });
      }
      continue;
    }

    // Server tools: give middleware the chance to skip or rewrite first.
    if (middlewareHooks) {
      const decision = await applyBeforeToolCallDecision(
        toolCall,
        tool,
        input,
        toolName,
        middlewareHooks,
        results
      );
      if (!decision.proceed) {
        continue;
      }
      input = decision.input;
    }

    const pendingEvents = [];
    const context = {
      toolCallId,
      context: userContext,
      abortSignal,
      emitCustomEvent: (eventName, value) => {
        if (!createCustomEventChunk) {
          return;
        }
        pendingEvents.push(createCustomEventChunk(eventName, { ...value, toolCallId }));
      },
    };
    yield* executeServerTool(
      toolCall,
      tool,
      toolName,
      input,
      context,
      pendingEvents,
      results,
      middlewareHooks
    );
  }

  return { results, needsApproval, needsClientExecution };
}
export {
MiddlewareAbortError,
ToolCallManager,
executeToolCalls
};
//# sourceMappingURL=tool-calls.js.map