@tanstack/ai
Version:
Core TanStack AI library - Open source AI SDK
879 lines (878 loc) • 27.5 kB
JavaScript
import { devtoolsMiddleware } from "@tanstack/ai-event-client";
import { streamToText } from "../../stream-to-response.js";
import { LazyToolManager } from "./tools/lazy-tool-manager.js";
import { ToolCallManager, MiddlewareAbortError, executeToolCalls } from "./tools/tool-calls.js";
import { convertSchemaToJsonSchema, isStandardSchema, parseWithStandardSchema } from "./tools/schema-converter.js";
import { maxIterations } from "./agent-loop-strategies.js";
import { convertMessagesToModelMessages } from "./messages.js";
import { MiddlewareRunner } from "./middleware/compose.js";
/** Activity kind tag exported by this module. */
const kind = "text";
/**
 * Returns the given chat options object unchanged (same reference).
 * NOTE(review): presumably exists so TypeScript callers get inference on the
 * options shape; the compiled JS is a pure identity function.
 *
 * @param {object} options - Chat options.
 * @returns {object} The very same `options` object.
 */
function createChatOptions(options) {
  const chatOptions = options;
  return chatOptions;
}
/**
 * Drives a single chat run against a model adapter.
 *
 * The run alternates between two kinds of cycles until the agent-loop
 * strategy stops it, it is cancelled, a RUN_ERROR chunk arrives, or it must
 * pause for client-side work (tool approval / client tool execution):
 *   - "processText": stream one model response through the middleware
 *     pipeline, collecting text and tool-call chunks;
 *   - "executeToolCalls": execute the tool calls that response requested and
 *     append their results to the message history.
 * `iterationCount` advances once per completed text + tools pair.
 *
 * Middleware observes and transforms the run through `middlewareCtx` and the
 * `MiddlewareRunner` hooks. At most one terminal hook (onFinish / onAbort /
 * onError) fires per run, guarded by `terminalHookCalled`.
 */
class TextEngine {
  /**
   * @param {object} config
   * @param {object} config.adapter - Model adapter; `name` and `chatStream()` are used here.
   * @param {object} config.params - Chat parameters (model, messages, tools, sampling options, ...).
   * @param {Array} [config.middleware] - User middleware, installed after the devtools middleware.
   * @param {*} [config.context] - Opaque user value exposed to middleware as `ctx.context`.
   */
  constructor(config) {
    // Per-run mutable state.
    this.iterationCount = 0;
    this.lastFinishReason = null;
    this.streamStartTime = 0;
    this.totalChunkCount = 0;
    this.currentMessageId = null;
    this.accumulatedContent = "";
    this.finishedEvent = null;
    this.earlyTermination = false;
    // "continue" | "stop" | "wait": outcome of the most recent tool phase.
    this.toolPhase = "continue";
    // "processText" | "executeToolCalls": which half of the cycle runs next.
    this.cyclePhase = "processText";
    // Promises registered by middleware via ctx.defer(); settled before run() ends.
    this.deferredPromises = [];
    // Ensures at most one of onFinish / onAbort / onError is invoked.
    this.terminalHookCalled = false;
    this.adapter = config.adapter;
    this.params = config.params;
    this.systemPrompts = config.params.systemPrompts || [];
    this.loopStrategy = config.params.agentLoopStrategy || maxIterations(5);
    this.initialMessageCount = config.params.messages.length;
    // Approval / client-tool state lives in the original messages' `parts`,
    // which the ModelMessage conversion below drops, so capture it first.
    const { approvals, clientToolResults } = this.extractClientStateFromOriginalMessages(
      config.params.messages
    );
    this.initialApprovals = approvals;
    this.initialClientToolResults = clientToolResults;
    this.messages = convertMessagesToModelMessages(config.params.messages);
    this.lazyToolManager = new LazyToolManager(config.params.tools || [], this.messages);
    this.tools = this.lazyToolManager.getActiveTools();
    this.toolCallManager = new ToolCallManager(this.tools);
    this.requestId = this.createId("chat");
    this.streamId = this.createId("stream");
    this.effectiveRequest = config.params.abortController
      ? { signal: config.params.abortController.signal }
      : void 0;
    this.effectiveSignal = config.params.abortController?.signal;
    // The devtools middleware is always installed, ahead of user middleware.
    const allMiddleware = [devtoolsMiddleware(), ...(config.middleware || [])];
    this.middlewareRunner = new MiddlewareRunner(allMiddleware);
    // Separate controller so middleware can abort without touching the caller's signal.
    this.middlewareAbortController = new AbortController();
    this.middlewareCtx = {
      requestId: this.requestId,
      streamId: this.streamId,
      conversationId: config.params.conversationId,
      phase: "init",
      iteration: 0,
      chunkIndex: 0,
      signal: this.effectiveSignal,
      abort: (reason) => {
        this.abortReason = reason;
        this.middlewareAbortController?.abort(reason);
      },
      context: config.context,
      defer: (promise) => {
        this.deferredPromises.push(promise);
      },
      // Provider / adapter info
      provider: config.adapter.name,
      model: config.params.model,
      source: "server",
      streaming: true,
      // Config-derived (updated in beforeRun and applyMiddlewareConfig)
      systemPrompts: this.systemPrompts,
      toolNames: void 0,
      options: void 0,
      modelOptions: config.params.modelOptions,
      // Computed
      messageCount: this.initialMessageCount,
      hasTools: this.tools.length > 0,
      // Mutable per-iteration
      currentMessageId: null,
      accumulatedContent: "",
      // References
      messages: this.messages,
      createId: (prefix) => this.createId(prefix)
    };
  }

  /** Get the accumulated content after the chat loop completes */
  getAccumulatedContent() {
    return this.accumulatedContent;
  }

  /** Get the final messages array after the chat loop completes */
  getMessages() {
    return this.messages;
  }

  /**
   * Run the chat loop, yielding every chunk destined for the client.
   *
   * Errors other than MiddlewareAbortError are re-thrown after the onError
   * hook; MiddlewareAbortError is swallowed after the onAbort hook. Deferred
   * middleware promises are always settled before the generator completes.
   */
  async *run() {
    this.beforeRun();
    try {
      this.middlewareCtx.phase = "init";
      const initialConfig = this.buildMiddlewareConfig();
      const transformedConfig = await this.middlewareRunner.runOnConfig(
        this.middlewareCtx,
        initialConfig
      );
      this.applyMiddlewareConfig(transformedConfig);
      await this.middlewareRunner.runOnStart(this.middlewareCtx);
      // Resume tool calls left unanswered in the incoming history (e.g. after
      // the client approved a tool or supplied a client-tool result).
      const pendingPhase = yield* this.checkForPendingToolCalls();
      if (pendingPhase === "wait") {
        return;
      }
      do {
        if (this.earlyTermination || this.isCancelled()) {
          return;
        }
        await this.beginCycle();
        if (this.cyclePhase === "processText") {
          this.middlewareCtx.phase = "beforeModel";
          this.middlewareCtx.iteration = this.iterationCount;
          const iterConfig = this.buildMiddlewareConfig();
          const transformedIterConfig = await this.middlewareRunner.runOnConfig(
            this.middlewareCtx,
            iterConfig
          );
          this.applyMiddlewareConfig(transformedIterConfig);
          yield* this.streamModelResponse();
        } else {
          yield* this.processToolCalls();
        }
        this.endCycle();
      } while (this.shouldContinue());
      // A cancelled run must report through onAbort (in `finally`), not
      // onFinish. Cancellation during the tool phase ends the loop through
      // toolPhase "stop" rather than through the check at the loop head.
      if (
        !this.terminalHookCalled &&
        this.toolPhase !== "wait" &&
        !this.isCancelled()
      ) {
        this.terminalHookCalled = true;
        await this.middlewareRunner.runOnFinish(this.middlewareCtx, {
          finishReason: this.lastFinishReason,
          duration: Date.now() - this.streamStartTime,
          content: this.accumulatedContent,
          usage: this.finishedEvent?.usage
        });
      }
    } catch (error) {
      if (!this.terminalHookCalled) {
        this.terminalHookCalled = true;
        if (error instanceof MiddlewareAbortError) {
          this.abortReason = error.message;
          await this.middlewareRunner.runOnAbort(this.middlewareCtx, {
            reason: error.message,
            duration: Date.now() - this.streamStartTime
          });
        } else {
          await this.middlewareRunner.runOnError(this.middlewareCtx, {
            error,
            duration: Date.now() - this.streamStartTime
          });
        }
      }
      // A middleware-initiated abort ends the run quietly; anything else propagates.
      if (!(error instanceof MiddlewareAbortError)) {
        throw error;
      }
    } finally {
      // Covers returns/early exits caused by the caller's signal or ctx.abort().
      if (!this.terminalHookCalled && this.isCancelled()) {
        this.terminalHookCalled = true;
        await this.middlewareRunner.runOnAbort(this.middlewareCtx, {
          reason: this.abortReason,
          duration: Date.now() - this.streamStartTime
        });
      }
      if (this.deferredPromises.length > 0) {
        await Promise.allSettled(this.deferredPromises);
      }
    }
  }

  /** Record the start time and publish request options / tool names to middleware. */
  beforeRun() {
    this.streamStartTime = Date.now();
    const { tools, temperature, topP, maxTokens, metadata } = this.params;
    const options = {};
    if (temperature !== void 0) options.temperature = temperature;
    if (topP !== void 0) options.topP = topP;
    if (maxTokens !== void 0) options.maxTokens = maxTokens;
    if (metadata !== void 0) options.metadata = metadata;
    this.eventOptions = Object.keys(options).length > 0 ? options : void 0;
    this.eventToolNames = tools?.map((t) => t.name);
    this.middlewareCtx.options = this.eventOptions;
    this.middlewareCtx.toolNames = this.eventToolNames;
  }

  /** Start a new model iteration when the next cycle is a text cycle. */
  async beginCycle() {
    if (this.cyclePhase === "processText") {
      await this.beginIteration();
    }
  }

  /** Flip to the other half of the cycle; a finished tool half completes one iteration. */
  endCycle() {
    if (this.cyclePhase === "processText") {
      this.cyclePhase = "executeToolCalls";
      return;
    }
    this.cyclePhase = "processText";
    this.iterationCount++;
  }

  /** Reset per-iteration state under a fresh message id and notify middleware. */
  async beginIteration() {
    this.currentMessageId = this.createId("msg");
    this.accumulatedContent = "";
    this.finishedEvent = null;
    this.middlewareCtx.currentMessageId = this.currentMessageId;
    this.middlewareCtx.accumulatedContent = "";
    await this.middlewareRunner.runOnIteration(this.middlewareCtx, {
      iteration: this.iterationCount,
      messageId: this.currentMessageId
    });
  }

  /**
   * Stream one model response. Every adapter chunk passes through the
   * middleware onChunk pipeline (which may emit zero or more chunks); each
   * output chunk is yielded, then folded into engine state.
   */
  async *streamModelResponse() {
    const { temperature, topP, maxTokens, metadata, modelOptions } = this.params;
    // Adapters receive tool schemas already converted to JSON Schema.
    const toolsWithJsonSchemas = this.tools.map((tool) => ({
      ...tool,
      inputSchema: tool.inputSchema ? convertSchemaToJsonSchema(tool.inputSchema) : void 0,
      outputSchema: tool.outputSchema ? convertSchemaToJsonSchema(tool.outputSchema) : void 0
    }));
    this.middlewareCtx.phase = "modelStream";
    for await (const chunk of this.adapter.chatStream({
      model: this.params.model,
      messages: this.messages,
      tools: toolsWithJsonSchemas,
      temperature,
      topP,
      maxTokens,
      metadata,
      request: this.effectiveRequest,
      modelOptions,
      systemPrompts: this.systemPrompts
    })) {
      if (this.isCancelled()) {
        break;
      }
      this.totalChunkCount++;
      const outputChunks = await this.middlewareRunner.runOnChunk(this.middlewareCtx, chunk);
      for (const outputChunk of outputChunks) {
        yield outputChunk;
        this.handleStreamChunk(outputChunk);
        this.middlewareCtx.chunkIndex++;
      }
      // Usage is reported from the raw adapter chunk, independent of onChunk output.
      if (chunk.type === "RUN_FINISHED" && chunk.usage) {
        await this.middlewareRunner.runOnUsage(this.middlewareCtx, chunk.usage);
      }
      if (this.earlyTermination) {
        break;
      }
    }
  }

  /** Dispatch a post-middleware chunk to its state handler; other types are ignored. */
  handleStreamChunk(chunk) {
    switch (chunk.type) {
      // AG-UI Events
      case "TEXT_MESSAGE_CONTENT":
        this.handleTextMessageContentEvent(chunk);
        break;
      case "TOOL_CALL_START":
        this.handleToolCallStartEvent(chunk);
        break;
      case "TOOL_CALL_ARGS":
        this.handleToolCallArgsEvent(chunk);
        break;
      case "TOOL_CALL_END":
        this.handleToolCallEndEvent(chunk);
        break;
      case "RUN_FINISHED":
        this.handleRunFinishedEvent(chunk);
        break;
      case "RUN_ERROR":
        this.handleRunErrorEvent(chunk);
        break;
      case "STEP_FINISHED":
        this.handleStepFinishedEvent(chunk);
        break;
    }
  }

  // ===========================
  // AG-UI Event Handlers
  // ===========================

  /**
   * A non-empty `content` replaces the accumulated text; otherwise `delta` is
   * appended. A missing delta appends nothing (previously it appended the
   * string "undefined").
   */
  handleTextMessageContentEvent(chunk) {
    if (chunk.content) {
      this.accumulatedContent = chunk.content;
    } else {
      this.accumulatedContent += chunk.delta ?? "";
    }
    this.middlewareCtx.accumulatedContent = this.accumulatedContent;
  }

  handleToolCallStartEvent(chunk) {
    this.toolCallManager.addToolCallStartEvent(chunk);
  }

  handleToolCallArgsEvent(chunk) {
    this.toolCallManager.addToolCallArgsEvent(chunk);
  }

  handleToolCallEndEvent(chunk) {
    this.toolCallManager.completeToolCall(chunk);
  }

  handleRunFinishedEvent(chunk) {
    this.finishedEvent = chunk;
    this.lastFinishReason = chunk.finishReason;
  }

  /** A RUN_ERROR chunk stops the stream and the loop; it is still yielded to the caller. */
  handleRunErrorEvent(_chunk) {
    this.earlyTermination = true;
  }

  handleStepFinishedEvent(_chunk) {
  }

  /**
   * Split tool calls into those that can run now and error results for calls
   * targeting lazy tools that have not been discovered yet.
   *
   * @param {Array} toolCalls - Tool calls shaped `{ id, function: { name } }`.
   * @returns {{ executable: Array, undiscoveredResults: Array }}
   */
  partitionUndiscoveredLazyToolCalls(toolCalls) {
    const executable = [];
    const undiscoveredResults = [];
    for (const toolCall of toolCalls) {
      const toolName = toolCall.function.name;
      if (this.lazyToolManager.isUndiscoveredLazyTool(toolName)) {
        undiscoveredResults.push({
          toolCallId: toolCall.id,
          toolName,
          result: {
            error: this.lazyToolManager.getUndiscoveredToolError(toolName)
          },
          state: "output-error"
        });
      } else {
        executable.push(toolCall);
      }
    }
    return { executable, undiscoveredResults };
  }

  /** Build the before/after hooks that route tool execution through middleware. */
  createToolExecutionHooks() {
    return {
      onBeforeToolCall: async (toolCall, tool, args) => {
        const hookCtx = {
          toolCall,
          tool,
          args,
          toolName: toolCall.function.name,
          toolCallId: toolCall.id
        };
        return this.middlewareRunner.runOnBeforeToolCall(this.middlewareCtx, hookCtx);
      },
      onAfterToolCall: async (info) => {
        await this.middlewareRunner.runOnAfterToolCall(this.middlewareCtx, info);
      }
    };
  }

  /**
   * Yield the chunks that pause the run for client work: results of tools
   * that already completed, then approval requests, then client-tool input
   * requests. Lazily evaluated, so history mutations happen in yield order.
   *
   * @param {object} executionResult - Result returned by executeToolCalls.
   * @param {object} finishEvent - RUN_FINISHED event whose `model` tags the chunks.
   * @param {Map<string, string>} [argsMap] - When given, results are preceded by START/ARGS chunks.
   */
  *emitClientWorkChunks(executionResult, finishEvent, argsMap) {
    if (executionResult.results.length > 0) {
      yield* this.buildToolResultChunks(executionResult.results, finishEvent, argsMap);
    }
    yield* this.buildApprovalChunks(executionResult.needsApproval, finishEvent);
    yield* this.buildClientToolChunks(executionResult.needsClientExecution, finishEvent);
  }

  /**
   * Execute tool calls already present (and unanswered) in the incoming
   * history before calling the model again.
   *
   * These calls were never streamed in this run, so result chunks include
   * TOOL_CALL_START/ARGS.
   *
   * @returns {Promise<"continue"|"wait"|"stop">} via the generator's return value.
   */
  async *checkForPendingToolCalls() {
    const pendingToolCalls = this.getPendingToolCallsFromMessages();
    if (pendingToolCalls.length === 0) {
      return "continue";
    }
    const finishEvent = this.createSyntheticFinishedEvent();
    const { executable: executablePendingCalls, undiscoveredResults } =
      this.partitionUndiscoveredLazyToolCalls(pendingToolCalls);
    if (undiscoveredResults.length > 0) {
      for (const chunk of this.buildToolResultChunks(undiscoveredResults, finishEvent)) {
        yield chunk;
      }
    }
    if (executablePendingCalls.length === 0) {
      return "continue";
    }
    const { approvals, clientToolResults } = this.collectClientState();
    const generator = executeToolCalls(
      executablePendingCalls,
      this.tools,
      approvals,
      clientToolResults,
      (eventName, data) => this.createCustomEventChunk(eventName, data),
      this.createToolExecutionHooks()
    );
    const executionResult = yield* this.drainToolCallGenerator(generator);
    if (this.isMiddlewareAborted()) {
      this.setToolPhase("stop");
      return "stop";
    }
    await this.middlewareRunner.runOnToolPhaseComplete(this.middlewareCtx, {
      toolCalls: pendingToolCalls,
      results: executionResult.results,
      needsApproval: executionResult.needsApproval,
      needsClientExecution: executionResult.needsClientExecution
    });
    const argsMap = new Map(
      pendingToolCalls.map((tc) => [tc.id, tc.function.arguments])
    );
    if (
      executionResult.needsApproval.length > 0 ||
      executionResult.needsClientExecution.length > 0
    ) {
      yield* this.emitClientWorkChunks(executionResult, finishEvent, argsMap);
      this.setToolPhase("wait");
      return "wait";
    }
    for (const chunk of this.buildToolResultChunks(executionResult.results, finishEvent, argsMap)) {
      yield chunk;
    }
    return "continue";
  }

  /**
   * Tool half of a cycle: execute the tool calls the last model response
   * requested, append results to the history and set `toolPhase` to
   * "continue", "wait" (client work needed) or "stop".
   */
  async *processToolCalls() {
    if (!this.shouldExecuteToolPhase()) {
      this.setToolPhase("stop");
      return;
    }
    const toolCalls = this.toolCallManager.getToolCalls();
    const finishEvent = this.finishedEvent;
    if (!finishEvent || toolCalls.length === 0) {
      this.setToolPhase("stop");
      return;
    }
    this.addAssistantToolCallMessage(toolCalls);
    const { executable: executableToolCalls, undiscoveredResults } =
      this.partitionUndiscoveredLazyToolCalls(toolCalls);
    if (undiscoveredResults.length > 0) {
      for (const chunk of this.buildToolResultChunks(undiscoveredResults, finishEvent)) {
        yield chunk;
      }
    }
    if (executableToolCalls.length === 0) {
      this.toolCallManager.clear();
      this.setToolPhase("continue");
      return;
    }
    this.middlewareCtx.phase = "beforeTools";
    const { approvals, clientToolResults } = this.collectClientState();
    const generator = executeToolCalls(
      executableToolCalls,
      this.tools,
      approvals,
      clientToolResults,
      (eventName, data) => this.createCustomEventChunk(eventName, data),
      this.createToolExecutionHooks()
    );
    const executionResult = yield* this.drainToolCallGenerator(generator);
    this.middlewareCtx.phase = "afterTools";
    if (this.isMiddlewareAborted()) {
      this.setToolPhase("stop");
      return;
    }
    await this.middlewareRunner.runOnToolPhaseComplete(this.middlewareCtx, {
      toolCalls,
      results: executionResult.results,
      needsApproval: executionResult.needsApproval,
      needsClientExecution: executionResult.needsClientExecution
    });
    if (
      executionResult.needsApproval.length > 0 ||
      executionResult.needsClientExecution.length > 0
    ) {
      yield* this.emitClientWorkChunks(executionResult, finishEvent);
      this.setToolPhase("wait");
      return;
    }
    for (const chunk of this.buildToolResultChunks(executionResult.results, finishEvent)) {
      yield chunk;
    }
    if (this.lazyToolManager.hasNewlyDiscoveredTools()) {
      // Newly discovered lazy tools become callable from the next model call.
      this.tools = this.lazyToolManager.getActiveTools();
      this.toolCallManager = new ToolCallManager(this.tools);
    } else {
      this.toolCallManager.clear();
    }
    this.setToolPhase("continue");
  }

  /** True when the last response finished for tool calls, tools exist, and calls were collected. */
  shouldExecuteToolPhase() {
    return (
      this.finishedEvent?.finishReason === "tool_calls" &&
      this.tools.length > 0 &&
      this.toolCallManager.hasToolCalls()
    );
  }

  /** Append the assistant turn that requested `toolCalls` (with any text it produced). */
  addAssistantToolCallMessage(toolCalls) {
    this.messages = [
      ...this.messages,
      {
        role: "assistant",
        content: this.accumulatedContent || null,
        toolCalls
      }
    ];
  }

  /**
   * Extract client state (approvals and client tool results) from original messages.
   * This is called in the constructor BEFORE converting to ModelMessage format,
   * because the parts array (which contains approval state) is lost during conversion.
   *
   * @returns {{ approvals: Map<string, boolean>, clientToolResults: Map<string, *> }}
   *   approvals keyed by approval id; client results keyed by tool-call part id.
   */
  extractClientStateFromOriginalMessages(originalMessages) {
    const approvals = new Map();
    const clientToolResults = new Map();
    for (const message of originalMessages) {
      if (message.role === "assistant" && message.parts) {
        for (const part of message.parts) {
          if (part.type === "tool-call") {
            // Output on a part without approval metadata = client-executed tool result.
            if (part.output !== void 0 && !part.approval) {
              clientToolResults.set(part.id, part.output);
            }
            if (
              part.approval?.id &&
              part.approval?.approved !== void 0 &&
              part.state === "approval-responded"
            ) {
              approvals.set(part.approval.id, part.approval.approved);
            }
          }
        }
      }
    }
    return { approvals, clientToolResults };
  }

  /**
   * Merge the initial client state with tool results currently in the
   * history. Tool messages that are `{ pendingExecution: true }` placeholders
   * are skipped; unparsable content is used as the raw string.
   */
  collectClientState() {
    const approvals = new Map(this.initialApprovals);
    const clientToolResults = new Map(this.initialClientToolResults);
    for (const message of this.messages) {
      if (message.role === "tool" && message.toolCallId) {
        let output;
        try {
          output = JSON.parse(message.content);
        } catch {
          output = message.content;
        }
        if (output && typeof output === "object" && output.pendingExecution === true) {
          continue;
        }
        clientToolResults.set(message.toolCallId, output);
      }
    }
    return { approvals, clientToolResults };
  }

  /** CUSTOM "approval-requested" chunks, one per tool call awaiting approval. */
  buildApprovalChunks(approvals, finishEvent) {
    const chunks = [];
    for (const approval of approvals) {
      chunks.push({
        type: "CUSTOM",
        timestamp: Date.now(),
        model: finishEvent.model,
        name: "approval-requested",
        value: {
          toolCallId: approval.toolCallId,
          toolName: approval.toolName,
          input: approval.input,
          approval: {
            id: approval.approvalId,
            needsApproval: true
          }
        }
      });
    }
    return chunks;
  }

  /** CUSTOM "tool-input-available" chunks, one per tool the client must execute. */
  buildClientToolChunks(clientRequests, finishEvent) {
    const chunks = [];
    for (const clientTool of clientRequests) {
      chunks.push({
        type: "CUSTOM",
        timestamp: Date.now(),
        model: finishEvent.model,
        name: "tool-input-available",
        value: {
          toolCallId: clientTool.toolCallId,
          toolName: clientTool.toolName,
          input: clientTool.input
        }
      });
    }
    return chunks;
  }

  /**
   * Build TOOL_CALL_END chunks for tool results AND append a matching "tool"
   * message to the history for each result (side effect).
   *
   * @param {Map<string, string>} [argsMap] - When provided, each END is
   *   preceded by TOOL_CALL_START/ARGS (for calls not streamed in this run);
   *   missing args default to "{}".
   */
  buildToolResultChunks(results, finishEvent, argsMap) {
    const chunks = [];
    for (const result of results) {
      const content = JSON.stringify(result.result);
      if (argsMap) {
        chunks.push({
          type: "TOOL_CALL_START",
          timestamp: Date.now(),
          model: finishEvent.model,
          toolCallId: result.toolCallId,
          toolName: result.toolName
        });
        const args = argsMap.get(result.toolCallId) ?? "{}";
        chunks.push({
          type: "TOOL_CALL_ARGS",
          timestamp: Date.now(),
          model: finishEvent.model,
          toolCallId: result.toolCallId,
          delta: args,
          args
        });
      }
      chunks.push({
        type: "TOOL_CALL_END",
        timestamp: Date.now(),
        model: finishEvent.model,
        toolCallId: result.toolCallId,
        toolName: result.toolName,
        result: content
      });
      this.messages = [
        ...this.messages,
        {
          role: "tool",
          content,
          toolCallId: result.toolCallId
        }
      ];
    }
    return chunks;
  }

  /**
   * Assistant tool calls in the history with no answering tool message. A
   * tool message whose JSON content has `pendingExecution: true` does not
   * count as an answer.
   */
  getPendingToolCallsFromMessages() {
    const completedToolIds = new Set();
    for (const message of this.messages) {
      if (message.role === "tool" && message.toolCallId) {
        let hasPendingExecution = false;
        if (typeof message.content === "string") {
          try {
            const parsed = JSON.parse(message.content);
            if (parsed.pendingExecution === true) {
              hasPendingExecution = true;
            }
          } catch {
            // Non-JSON (or null) content is treated as a completed result.
          }
        }
        if (!hasPendingExecution) {
          completedToolIds.add(message.toolCallId);
        }
      }
    }
    const pending = [];
    for (const message of this.messages) {
      if (message.role === "assistant" && message.toolCalls) {
        for (const toolCall of message.toolCalls) {
          if (!completedToolIds.has(toolCall.id)) {
            pending.push(toolCall);
          }
        }
      }
    }
    return pending;
  }

  /** Stand-in RUN_FINISHED event used to tag chunks for resumed tool calls. */
  createSyntheticFinishedEvent() {
    return {
      type: "RUN_FINISHED",
      runId: this.createId("pending"),
      model: this.params.model,
      timestamp: Date.now(),
      finishReason: "tool_calls"
    };
  }

  /**
   * The tool half always follows a text half. After a tool half, continue
   * only if the loop strategy agrees and the tool phase said "continue".
   */
  shouldContinue() {
    if (this.cyclePhase === "executeToolCalls") {
      return true;
    }
    return this.loopStrategy({
      iterationCount: this.iterationCount,
      messages: this.messages,
      finishReason: this.lastFinishReason
    }) && this.toolPhase === "continue";
  }

  /** Caller's abort signal fired. */
  isAborted() {
    return !!this.effectiveSignal?.aborted;
  }

  /** Middleware called ctx.abort(). */
  isMiddlewareAborted() {
    return !!this.middlewareAbortController?.signal.aborted;
  }

  isCancelled() {
    return this.isAborted() || this.isMiddlewareAborted();
  }

  /** Snapshot of the transformable config passed to middleware onConfig. */
  buildMiddlewareConfig() {
    return {
      messages: this.messages,
      systemPrompts: [...this.systemPrompts],
      tools: [...this.tools],
      temperature: this.params.temperature,
      topP: this.params.topP,
      maxTokens: this.params.maxTokens,
      metadata: this.params.metadata,
      modelOptions: this.params.modelOptions
    };
  }

  /** Adopt a middleware-transformed config and mirror it into middlewareCtx. */
  applyMiddlewareConfig(config) {
    this.messages = config.messages;
    this.systemPrompts = config.systemPrompts;
    this.tools = config.tools;
    this.params = {
      ...this.params,
      temperature: config.temperature,
      topP: config.topP,
      maxTokens: config.maxTokens,
      metadata: config.metadata,
      modelOptions: config.modelOptions
    };
    this.middlewareCtx.messages = this.messages;
    this.middlewareCtx.systemPrompts = this.systemPrompts;
    this.middlewareCtx.hasTools = this.tools.length > 0;
    this.middlewareCtx.toolNames = this.tools.map((t) => t.name);
    this.middlewareCtx.modelOptions = config.modelOptions;
  }

  setToolPhase(phase) {
    this.toolPhase = phase;
  }

  /**
   * Drain an executeToolCalls async generator, yielding any CustomEvent chunks
   * and returning the final ExecuteToolCallsResult.
   */
  async *drainToolCallGenerator(generator) {
    let next = await generator.next();
    while (!next.done) {
      yield next.value;
      next = await generator.next();
    }
    return next.value;
  }

  createCustomEventChunk(eventName, value) {
    return {
      type: "CUSTOM",
      timestamp: Date.now(),
      model: this.params.model,
      name: eventName,
      value
    };
  }

  /** `${prefix}-${ms timestamp}-${7 base-36 chars}`; Math.random based, for correlation only. */
  createId(prefix) {
    return `${prefix}-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`;
  }
}
/**
 * Entry point for the text chat activity; picks an execution mode from the options.
 *
 * - `outputSchema` present: run the agent loop, then request structured output
 *   (returns the Promise from runAgenticStructuredOutput).
 * - `stream === false`: run the agent loop and return the result of collecting
 *   the stream with `streamToText`.
 * - otherwise: return the async generator of stream chunks.
 *
 * @param {object} options - Chat options, including `adapter`.
 */
function chat(options) {
  const { outputSchema: schema, stream: streamMode } = options;
  if (schema) {
    return runAgenticStructuredOutput(options);
  }
  return streamMode === false
    ? runNonStreamingText(options)
    : runStreamingText(options);
}
/**
 * Streaming mode: run a TextEngine and re-yield every chunk it produces.
 * The adapter's `model` overrides any `model` key in the remaining options.
 *
 * @param {object} options - Chat options; `adapter`, `middleware` and `context`
 *   are routed to the engine config, everything else becomes engine params.
 */
async function* runStreamingText(options) {
  const { adapter, middleware, context, ...textOptions } = options;
  const engine = new TextEngine({
    adapter,
    middleware,
    context,
    params: { ...textOptions, model: adapter.model }
  });
  for await (const streamedChunk of engine.run()) {
    yield streamedChunk;
  }
}
/**
 * Non-streaming mode: run the streaming pipeline and hand the whole stream to
 * `streamToText`, returning whatever it returns.
 *
 * @param {object} options - Chat options forwarded unchanged to runStreamingText.
 */
function runNonStreamingText(options) {
  return streamToText(runStreamingText(options));
}
/**
 * Structured-output mode.
 *
 * 1. Drive the agent loop (model + tools) to completion, discarding chunks.
 * 2. Ask the adapter for structured output over the final message history.
 *    `tools` and `agentLoopStrategy` are left out of the adapter request.
 * 3. If `outputSchema` is a Standard Schema, validate/parse the data with it;
 *    otherwise return the adapter's data as-is.
 *
 * @param {object} options - Chat options; must include `outputSchema`.
 * @returns {Promise<*>} The structured result.
 * @throws {Error} When `outputSchema` is missing or cannot be converted to JSON Schema.
 */
async function runAgenticStructuredOutput(options) {
  const { adapter, outputSchema, middleware, context, ...textOptions } = options;
  const { model } = adapter;
  if (!outputSchema) {
    throw new Error("outputSchema is required for structured output");
  }
  const engine = new TextEngine({
    adapter,
    params: { ...textOptions, model },
    middleware,
    context
  });
  // Only the final history matters here; the chunks themselves are dropped.
  const runStream = engine.run();
  for await (const _ignoredChunk of runStream) {
  }
  const finalMessages = engine.getMessages();
  const {
    tools: _omittedTools,
    agentLoopStrategy: _omittedStrategy,
    ...structuredTextOptions
  } = textOptions;
  const jsonSchema = convertSchemaToJsonSchema(outputSchema);
  if (!jsonSchema) {
    throw new Error("Failed to convert output schema to JSON Schema");
  }
  const { data } = await adapter.structuredOutput({
    chatOptions: {
      ...structuredTextOptions,
      model,
      messages: finalMessages
    },
    outputSchema: jsonSchema
  });
  return isStandardSchema(outputSchema)
    ? parseWithStandardSchema(outputSchema, data)
    : data;
}
export {
chat,
createChatOptions,
kind
};
//# sourceMappingURL=index.js.map