@tanstack/ai
Version:
Core TanStack AI library - Open source AI SDK
460 lines (459 loc) • 13.6 kB
JavaScript
import { isStandardSchema, parseWithStandardSchema } from "./schema-converter.js";
/**
 * Best-effort JSON decoding.
 *
 * @param {*} value - Candidate JSON text.
 * @returns {*} The decoded value when `JSON.parse` accepts `value`;
 *   otherwise `value` itself, unchanged.
 */
function safeJsonParse(value) {
  let parsed = value;
  try {
    parsed = JSON.parse(value);
  } catch {
    // Not valid JSON: keep the raw input as the result.
  }
  return parsed;
}
/**
 * Error type used to signal that middleware aborted a tool call.
 * Its `name` is set to "MiddlewareAbortError" so it can be told apart
 * from ordinary errors.
 */
class MiddlewareAbortError extends Error {
  /**
   * @param {string} reason - Human-readable explanation; becomes `message`.
   */
  constructor(reason) {
    super(reason);
    this.name = "MiddlewareAbortError";
  }
}
/**
 * Accumulates streamed tool-call events and executes the resulting calls
 * against a list of tool definitions.
 *
 * Tool calls are kept in a Map keyed by the event's `index` (or, when the
 * event has none, by the Map's size at the time the call starts).
 */
class ToolCallManager {
  /**
   * @param {Array} tools - Tool definitions. Each one is matched by `name`
   *   and run through its optional `execute` function.
   */
  constructor(tools) {
    this.toolCallsMap = /* @__PURE__ */ new Map();
    this.tools = tools;
  }

  /**
   * Add a TOOL_CALL_START event to begin tracking a tool call (AG-UI).
   *
   * @param {object} event - Provides `toolCallId`, `toolName` and optionally
   *   `index` and `providerMetadata`.
   */
  addToolCallStartEvent(event) {
    const index = event.index ?? this.toolCallsMap.size;
    this.toolCallsMap.set(index, {
      id: event.toolCallId,
      type: "function",
      function: {
        name: event.toolName,
        arguments: ""
      },
      // Only attach the key when the provider actually sent metadata.
      ...event.providerMetadata && {
        providerMetadata: event.providerMetadata
      }
    });
  }

  /**
   * Add a TOOL_CALL_ARGS event to accumulate arguments (AG-UI).
   * The delta is appended to the first tracked call with a matching id;
   * events for unknown ids are ignored.
   *
   * @param {object} event - Provides `toolCallId` and the string `delta`.
   */
  addToolCallArgsEvent(event) {
    for (const toolCall of this.toolCallsMap.values()) {
      if (toolCall.id === event.toolCallId) {
        toolCall.function.arguments += event.delta;
        break;
      }
    }
  }

  /**
   * Complete a tool call with its final input.
   * Called when TOOL_CALL_END is received. When `event.input` is defined it
   * replaces the accumulated argument text with its JSON serialization.
   *
   * @param {object} event - Provides `toolCallId` and optionally `input`.
   */
  completeToolCall(event) {
    for (const toolCall of this.toolCallsMap.values()) {
      if (toolCall.id === event.toolCallId) {
        if (event.input !== void 0) {
          toolCall.function.arguments = JSON.stringify(event.input);
        }
        break;
      }
    }
  }

  /**
   * Check if there are any complete tool calls to execute.
   *
   * @returns {boolean}
   */
  hasToolCalls() {
    return this.getToolCalls().length > 0;
  }

  /**
   * Get all complete tool calls (filtered for a truthy id and a non-blank name).
   *
   * @returns {Array}
   */
  getToolCalls() {
    return Array.from(this.toolCallsMap.values()).filter(
      (tc) => tc.id && tc.function.name && tc.function.name.trim().length > 0
    );
  }

  /**
   * Execute all tracked tool calls in order and return tool result messages.
   * Yields one TOOL_CALL_END event per call. Failures (bad JSON arguments,
   * schema validation, exceptions thrown by the tool) are not thrown; they
   * are reported as an "Error executing tool: ..." result string.
   *
   * @param {object} finishEvent - Event whose `model` is copied onto every
   *   yielded TOOL_CALL_END event (documented upstream as the RUN_FINISHED event).
   * @returns {Array} (generator return value) `{ role: "tool", content, toolCallId }` messages.
   */
  async *executeTools(finishEvent) {
    const toolCallsArray = this.getToolCalls();
    const toolResults = [];
    for (const toolCall of toolCallsArray) {
      const tool = this.tools.find((t) => t.name === toolCall.function.name);
      let toolResultContent;
      if (tool?.execute) {
        try {
          let args;
          try {
            // Empty text and a literal "null" both mean "no arguments".
            const argsString = toolCall.function.arguments.trim() || "{}";
            args = JSON.parse(argsString === "null" ? "{}" : argsString);
          } catch {
            throw new Error(
              `Failed to parse tool arguments as JSON: ${toolCall.function.arguments}`
            );
          }
          if (tool.inputSchema && isStandardSchema(tool.inputSchema)) {
            try {
              args = parseWithStandardSchema(tool.inputSchema, args);
            } catch (validationError) {
              const message = validationError instanceof Error ? validationError.message : "Validation failed";
              throw new Error(
                `Input validation failed for tool ${tool.name}: ${message}`
              );
            }
          }
          let result = await tool.execute(args);
          if (tool.outputSchema && isStandardSchema(tool.outputSchema) && result !== void 0 && result !== null) {
            try {
              result = parseWithStandardSchema(tool.outputSchema, result);
            } catch (validationError) {
              const message = validationError instanceof Error ? validationError.message : "Validation failed";
              throw new Error(
                `Output validation failed for tool ${tool.name}: ${message}`
              );
            }
          }
          // JSON.stringify returns undefined (not a string) for undefined,
          // functions and symbols; fall back to "null" so the message content
          // is always a string.
          toolResultContent = typeof result === "string" ? result : (JSON.stringify(result) ?? "null");
        } catch (error) {
          const message = error instanceof Error ? error.message : "Unknown error";
          toolResultContent = `Error executing tool: ${message}`;
        }
      } else {
        // Reached both when the tool is unknown and when it has no `execute`.
        toolResultContent = `Tool ${toolCall.function.name} does not have an execute function`;
      }
      yield {
        type: "TOOL_CALL_END",
        toolCallId: toolCall.id,
        toolName: toolCall.function.name,
        model: finishEvent.model,
        timestamp: Date.now(),
        result: toolResultContent
      };
      toolResults.push({
        role: "tool",
        content: toolResultContent,
        toolCallId: toolCall.id
      });
    }
    return toolResults;
  }

  /**
   * Clear the tool calls map for the next iteration.
   */
  clear() {
    this.toolCallsMap.clear();
  }
}
/**
 * Await a tool execution while periodically flushing events it queued.
 *
 * Each poll iteration races the execution against a timer and then yields
 * every item currently in `pendingEvents` (removing them with `shift`).
 * Once the execution settles, remaining events are flushed and the resolved
 * value becomes the generator's return value. A rejection of
 * `executionPromise` propagates out of the generator.
 *
 * @param {Promise<*>} executionPromise - The in-flight execution.
 * @param {Array} pendingEvents - Queue drained in place; producers push onto it.
 * @param {number} [pollIntervalMs=10] - Delay between queue flushes.
 * @returns {*} (generator return value) The value `executionPromise` resolved with.
 */
async function* executeWithEventPolling(executionPromise, pendingEvents, pollIntervalMs = 10) {
  const state = { done: false, result: void 0 };
  const executionWithFlag = executionPromise.then((value) => {
    state.done = true;
    state.result = value;
    return value;
  });
  while (!state.done) {
    let timerId;
    try {
      await Promise.race([
        executionWithFlag,
        // The executor runs synchronously, so timerId is set before the await.
        new Promise((resolve) => {
          timerId = setTimeout(resolve, pollIntervalMs);
        })
      ]);
    } finally {
      // Don't leave the poll timer pending when the execution wins the race
      // or rejects.
      clearTimeout(timerId);
    }
    while (pendingEvents.length > 0) {
      yield pendingEvents.shift();
    }
  }
  // Flush anything queued while the consumer was handling earlier yields.
  while (pendingEvents.length > 0) {
    yield pendingEvents.shift();
  }
  return state.result;
}
/**
 * Run the `onBeforeToolCall` middleware hook (if any) and translate its
 * decision into an instruction for the caller.
 *
 * - no hook, or a falsy decision: proceed with the original `input`;
 * - `type: "abort"`: throw a MiddlewareAbortError carrying `decision.reason`
 *   (or "Aborted by middleware");
 * - `type: "skip"`: record `decision.result` in `results` with duration 0,
 *   notify `onAfterToolCall` (if present) and tell the caller not to proceed;
 * - any other decision: proceed with `decision.args` as the new input.
 *
 * @param {object} toolCall - The tool call; its `id` is recorded on results.
 * @param {object} tool - The tool definition, passed through to the hooks.
 * @param {*} input - Parsed tool arguments.
 * @param {string} toolName - Name recorded on results and hook payloads.
 * @param {object} middlewareHooks - May define `onBeforeToolCall` / `onAfterToolCall`.
 * @param {Array} results - Mutated: receives the result entry for a skipped call.
 * @returns {Promise<{proceed: boolean, input?: *}>}
 * @throws {MiddlewareAbortError} When the hook returns an abort decision.
 */
async function applyBeforeToolCallDecision(toolCall, tool, input, toolName, middlewareHooks, results) {
  if (!middlewareHooks.onBeforeToolCall) {
    return { proceed: true, input };
  }
  const decision = await middlewareHooks.onBeforeToolCall(toolCall, tool, input);
  if (!decision) {
    return { proceed: true, input };
  }
  if (decision.type === "abort") {
    throw new MiddlewareAbortError(decision.reason || "Aborted by middleware");
  }
  if (decision.type === "skip") {
    const skipResult = decision.result;
    results.push({
      toolCallId: toolCall.id,
      toolName,
      // `??` (not `||`) so legitimate falsy results such as `false` or `0`
      // are preserved; only undefined/null collapse to null.
      result: typeof skipResult === "string" ? safeJsonParse(skipResult) : skipResult ?? null,
      duration: 0
    });
    if (middlewareHooks.onAfterToolCall) {
      await middlewareHooks.onAfterToolCall({
        toolCall,
        tool,
        toolName,
        toolCallId: toolCall.id,
        ok: true,
        duration: 0,
        result: skipResult
      });
    }
    return { proceed: false };
  }
  return { proceed: true, input: decision.args };
}
/**
 * Execute a server-side tool and record its outcome.
 *
 * Runs `tool.execute(input, context)`, yielding queued custom events while it
 * is in flight (via executeWithEventPolling) and once more after it settles.
 * On success the (optionally output-schema-validated) result is pushed onto
 * `results`; string results are JSON-decoded when possible. On failure an
 * `{ error: message }` entry with `state: "output-error"` is pushed instead.
 * In both cases `onAfterToolCall` is invoked when the hook exists.
 *
 * NOTE(review): `onAfterToolCall` on the success path is awaited inside the
 * `try`, so an exception thrown by it is handled by the `catch` below after a
 * success entry has already been pushed — confirm that is intended.
 *
 * @param {object} toolCall - The tool call; its `id` is recorded on results.
 * @param {object} tool - Tool definition providing `execute` and optionally `outputSchema`.
 * @param {string} toolName - Name recorded on results and hook payloads.
 * @param {*} input - Parsed (and possibly middleware-transformed) arguments.
 * @param {object} context - Passed to `tool.execute` as its second argument.
 * @param {Array} pendingEvents - Queue of events to yield; drained in place.
 * @param {Array} results - Mutated: receives this call's result entry.
 * @param {object} [middlewareHooks] - May define `onAfterToolCall`.
 * @throws {MiddlewareAbortError} Re-thrown untouched when raised during execution.
 */
async function* executeServerTool(toolCall, tool, toolName, input, context, pendingEvents, results, middlewareHooks) {
  const startTime = Date.now();
  try {
    // Promise.resolve lets `execute` be either sync or async.
    const executionPromise = Promise.resolve(tool.execute(input, context));
    let result = yield* executeWithEventPolling(executionPromise, pendingEvents);
    const duration = Date.now() - startTime;
    while (pendingEvents.length > 0) {
      yield pendingEvents.shift();
    }
    if (tool.outputSchema && isStandardSchema(tool.outputSchema) && result !== void 0 && result !== null) {
      result = parseWithStandardSchema(tool.outputSchema, result);
    }
    // `??` (not `||`) so a tool that legitimately returns `false` or `0`
    // reports that value; only undefined/null collapse to null.
    const finalResult = typeof result === "string" ? safeJsonParse(result) : result ?? null;
    results.push({
      toolCallId: toolCall.id,
      toolName,
      result: finalResult,
      duration
    });
    if (middlewareHooks?.onAfterToolCall) {
      await middlewareHooks.onAfterToolCall({
        toolCall,
        tool,
        toolName,
        toolCallId: toolCall.id,
        ok: true,
        duration,
        result: finalResult
      });
    }
  } catch (error) {
    const duration = Date.now() - startTime;
    // Flush events emitted before the failure so they are not lost.
    while (pendingEvents.length > 0) {
      yield pendingEvents.shift();
    }
    if (error instanceof MiddlewareAbortError) {
      throw error;
    }
    const message = error instanceof Error ? error.message : "Unknown error";
    results.push({
      toolCallId: toolCall.id,
      toolName,
      result: { error: message },
      state: "output-error",
      duration
    });
    if (middlewareHooks?.onAfterToolCall) {
      await middlewareHooks.onAfterToolCall({
        toolCall,
        tool,
        toolName,
        toolCallId: toolCall.id,
        ok: false,
        duration,
        error
      });
    }
  }
}
/**
 * Resolve a batch of tool calls: execute server tools, and sort the rest
 * into "needs user approval" and "needs client-side execution".
 *
 * Per call, in order:
 *  1. unknown tool name -> error result;
 *  2. while any call in the batch still awaits approval, only the calls that
 *     are themselves awaiting approval are processed (the others are skipped);
 *  3. arguments are JSON-parsed (empty text and a literal "null" both mean
 *     `{}`) and validated against `tool.inputSchema` when it is a standard
 *     schema; validation failures become error results;
 *  4. approval gate: missing decision -> `needsApproval`; declined -> error result;
 *  5. tools without `execute` -> use the supplied client result or queue the
 *     call in `needsClientExecution`;
 *  6. otherwise run middleware (`onBeforeToolCall`) and execute on the server,
 *     yielding any custom events the tool emits.
 *
 * @param {Array} toolCalls - Calls shaped `{ id, function: { name, arguments } }`.
 * @param {Iterable} tools - Tool definitions, indexed by `name`.
 * @param {Map<string, boolean>} [approvals] - Keyed by `approval_<toolCallId>`.
 * @param {Map<string, *>} [clientResults] - Client-produced results keyed by tool call id.
 * @param {Function} [createCustomEventChunk] - `(eventName, value) => chunk`; when
 *   omitted, custom events emitted by tools are dropped.
 * @param {object} [middlewareHooks] - Optional `onBeforeToolCall` / `onAfterToolCall` hooks.
 * @returns {{results: Array, needsApproval: Array, needsClientExecution: Array}}
 *   (generator return value)
 * @throws {Error} When a call's arguments are not valid JSON (fails the whole batch).
 * @throws {MiddlewareAbortError} When middleware aborts a call.
 */
async function* executeToolCalls(toolCalls, tools, approvals = /* @__PURE__ */ new Map(), clientResults = /* @__PURE__ */ new Map(), createCustomEventChunk, middlewareHooks) {
  const results = [];
  const needsApproval = [];
  const needsClientExecution = [];
  const toolMap = /* @__PURE__ */ new Map();
  for (const tool of tools) {
    toolMap.set(tool.name, tool);
  }
  const hasPendingApprovals = toolCalls.some((tc) => {
    const t = toolMap.get(tc.function.name);
    return t?.needsApproval && !approvals.has(`approval_${tc.id}`);
  });
  for (const toolCall of toolCalls) {
    const toolName = toolCall.function.name;
    const tool = toolMap.get(toolName);
    if (!tool) {
      results.push({
        toolCallId: toolCall.id,
        toolName,
        result: { error: `Unknown tool: ${toolName}` },
        state: "output-error"
      });
      continue;
    }
    const approvalId = `approval_${toolCall.id}`;
    // While approvals are outstanding, defer everything except the calls that
    // are themselves waiting for a decision.
    if (hasPendingApprovals && (!tool.needsApproval || approvals.has(approvalId))) {
      continue;
    }
    // Empty text and a literal "null" both mean "no arguments"; this matches
    // the normalization done by ToolCallManager.executeTools.
    const argsStr = toolCall.function.arguments.trim() || "{}";
    let input;
    try {
      input = JSON.parse(argsStr === "null" ? "{}" : argsStr);
    } catch (parseError) {
      // Malformed arguments deliberately fail the whole batch.
      throw new Error(`Failed to parse tool arguments as JSON: ${argsStr}`, { cause: parseError });
    }
    if (tool.inputSchema && isStandardSchema(tool.inputSchema)) {
      try {
        input = parseWithStandardSchema(tool.inputSchema, input);
      } catch (validationError) {
        const message = validationError instanceof Error ? validationError.message : "Validation failed";
        results.push({
          toolCallId: toolCall.id,
          toolName,
          result: {
            error: `Input validation failed for tool ${tool.name}: ${message}`
          },
          state: "output-error"
        });
        continue;
      }
    }
    // Approval gate, shared by client-side and server-side tools.
    if (tool.needsApproval) {
      if (!approvals.has(approvalId)) {
        needsApproval.push({
          toolCallId: toolCall.id,
          toolName,
          input,
          approvalId
        });
        continue;
      }
      if (!approvals.get(approvalId)) {
        results.push({
          toolCallId: toolCall.id,
          toolName,
          result: { error: "User declined tool execution" },
          state: "output-error"
        });
        continue;
      }
    }
    // Tools without `execute` run on the client: use its result if we already
    // have one, otherwise ask for it.
    if (!tool.execute) {
      if (clientResults.has(toolCall.id)) {
        results.push({
          toolCallId: toolCall.id,
          toolName,
          result: clientResults.get(toolCall.id)
        });
      } else {
        needsClientExecution.push({
          toolCallId: toolCall.id,
          toolName,
          input
        });
      }
      continue;
    }
    if (middlewareHooks) {
      const decision = await applyBeforeToolCallDecision(
        toolCall,
        tool,
        input,
        toolName,
        middlewareHooks,
        results
      );
      if (!decision.proceed) continue;
      input = decision.input;
    }
    // Custom events emitted by the tool are queued here and yielded by
    // executeServerTool while the tool runs.
    const pendingEvents = [];
    const context = {
      toolCallId: toolCall.id,
      emitCustomEvent: (eventName, value) => {
        if (createCustomEventChunk) {
          pendingEvents.push(
            createCustomEventChunk(eventName, {
              ...value,
              toolCallId: toolCall.id
            })
          );
        }
      }
    };
    yield* executeServerTool(
      toolCall,
      tool,
      toolName,
      input,
      context,
      pendingEvents,
      results,
      middlewareHooks
    );
  }
  return { results, needsApproval, needsClientExecution };
}
export {
MiddlewareAbortError,
ToolCallManager,
executeToolCalls
};
//# sourceMappingURL=tool-calls.js.map