UNPKG

@tanstack/ai

Version:

Core TanStack AI library — an open-source AI SDK

420 lines (377 loc) 11.6 kB
import type { ModelMessage, StreamChunk, Tool, ToolCall } from '../../../types' // =========================== // Middleware Context // =========================== /** * Phase of the chat middleware lifecycle. * - 'init': Initial config transform before the chat engine starts * - 'beforeModel': Before each adapter chatStream call (per agent iteration) * - 'modelStream': During model streaming * - 'beforeTools': Before tool execution phase * - 'afterTools': After tool execution phase */ export type ChatMiddlewarePhase = | 'init' | 'beforeModel' | 'modelStream' | 'beforeTools' | 'afterTools' /** * Stable context object passed to all middleware hooks. * Created once per chat() invocation and shared across all hooks. */ export interface ChatMiddlewareContext { /** Unique identifier for this chat request */ requestId: string /** Unique identifier for this stream */ streamId: string /** Conversation identifier, if provided by the caller */ conversationId?: string /** Current lifecycle phase */ phase: ChatMiddlewarePhase /** Current agent loop iteration (0-indexed) */ iteration: number /** Running count of chunks yielded so far */ chunkIndex: number /** Abort signal from the chat request */ signal?: AbortSignal /** Abort the chat run with a reason */ abort: (reason?: string) => void /** Opaque user-provided value from chat() options */ context: unknown /** * Defer a non-blocking side-effect promise. * Deferred promises do not block streaming and are awaited * after the terminal hook (onFinish/onAbort/onError). 
*/ defer: (promise: Promise<unknown>) => void // --- Provider / adapter info (immutable for the lifetime of the request) --- /** Provider name (e.g., 'openai', 'anthropic') */ provider: string /** Model identifier (e.g., 'gpt-4o') */ model: string /** Source of the chat invocation — always 'server' for server-side chat */ source: 'client' | 'server' /** Whether the chat is streaming */ streaming: boolean // --- Config-derived info (may update per-iteration via onConfig) --- /** System prompts configured for this chat */ systemPrompts: Array<string> /** Names of configured tools, if any */ toolNames?: Array<string> /** Flattened generation options (temperature, topP, maxTokens, metadata) */ options?: Record<string, unknown> /** Provider-specific model options */ modelOptions?: Record<string, unknown> // --- Computed info --- /** Number of messages at the start of the request */ messageCount: number /** Whether tools are configured */ hasTools: boolean // --- Mutable per-iteration state --- /** Current assistant message ID (changes per iteration) */ currentMessageId: string | null /** Accumulated text content for the current iteration */ accumulatedContent: string // --- References --- /** Current messages array (read-only view) */ messages: ReadonlyArray<ModelMessage> /** Generate a unique ID with the given prefix */ createId: (prefix: string) => string } // =========================== // Config passed to onConfig // =========================== /** * Chat configuration that middleware can observe or transform. * This is a subset of the chat engine's effective configuration * that middleware is allowed to modify. 
*/ export interface ChatMiddlewareConfig { messages: Array<ModelMessage> systemPrompts: Array<string> tools: Array<Tool> temperature?: number topP?: number maxTokens?: number metadata?: Record<string, unknown> modelOptions?: Record<string, unknown> } // =========================== // Tool Call Hook Context // =========================== /** * Context provided to tool call hooks (onBeforeToolCall / onAfterToolCall). */ export interface ToolCallHookContext { /** The tool call being executed */ toolCall: ToolCall /** The resolved tool definition, if found */ tool: Tool | undefined /** Parsed arguments for the tool call */ args: unknown /** Name of the tool */ toolName: string /** ID of the tool call */ toolCallId: string } /** * Decision returned from onBeforeToolCall. * - undefined/void: continue with normal execution * - { type: 'transformArgs', args }: replace args used for execution * - { type: 'skip', result }: skip execution, use provided result * - { type: 'abort', reason }: abort the entire chat run */ export type BeforeToolCallDecision = | void | undefined | null | { type: 'transformArgs'; args: unknown } | { type: 'skip'; result: unknown } | { type: 'abort'; reason?: string } /** * Outcome information provided to onAfterToolCall. */ export interface AfterToolCallInfo { /** The tool call that was executed */ toolCall: ToolCall /** The resolved tool definition */ tool: Tool | undefined /** Name of the tool */ toolName: string /** ID of the tool call */ toolCallId: string /** Whether the execution succeeded */ ok: boolean /** Duration of tool execution in milliseconds */ duration: number /** The result (if ok) or error (if not ok) */ result?: unknown error?: unknown } // =========================== // Iteration Info // =========================== /** * Information passed to onIteration at the start of each agent loop iteration. 
*/ export interface IterationInfo { /** 0-based iteration index */ iteration: number /** The assistant message ID created for this iteration */ messageId: string } // =========================== // Tool Phase Complete Info // =========================== /** * Aggregate information passed to onToolPhaseComplete after all tool calls * in an iteration have been processed. */ export interface ToolPhaseCompleteInfo { /** Tool calls that were assigned to the assistant message */ toolCalls: Array<ToolCall> /** Completed tool results */ results: Array<{ toolCallId: string toolName: string result: unknown duration?: number }> /** Tools that need user approval */ needsApproval: Array<{ toolCallId: string toolName: string input: unknown approvalId: string }> /** Tools that need client-side execution */ needsClientExecution: Array<{ toolCallId: string toolName: string input: unknown }> } // =========================== // Usage Info // =========================== /** * Token usage statistics passed to the onUsage hook. * Extracted from the RUN_FINISHED chunk when usage data is present. */ export interface UsageInfo { promptTokens: number completionTokens: number totalTokens: number } // =========================== // Terminal Hook Info // =========================== /** * Information passed to onFinish. */ export interface FinishInfo { /** The finish reason from the last model response */ finishReason: string | null /** Total duration of the chat run in milliseconds */ duration: number /** Final accumulated text content */ content: string /** Final usage totals, if available */ usage?: { promptTokens: number completionTokens: number totalTokens: number } } /** * Information passed to onAbort. */ export interface AbortInfo { /** The reason for the abort, if provided */ reason?: string /** Duration until abort in milliseconds */ duration: number } /** * Information passed to onError. 
*/
export interface ErrorInfo {
  /** The error that caused the failure */
  error: unknown
  /** Duration until error in milliseconds */
  duration: number
}

// ===========================
// Middleware Interface
// ===========================

/**
 * Chat middleware interface.
 *
 * All hooks are optional. Middleware is composed in array order:
 * - `onConfig`: config piped through middlewares in order (first transform influences later)
 * - `onChunk`: each output chunk is fed into the next middleware in order
 *
 * @example Logging middleware
 * ```ts
 * const loggingMiddleware: ChatMiddleware = {
 *   name: 'logging',
 *   onStart(ctx) { console.log('Chat started', ctx.requestId) },
 *   onChunk(ctx, chunk) { console.log('Chunk:', chunk.type) },
 *   onFinish(ctx, info) { console.log('Done:', info.duration, 'ms') },
 * }
 * ```
 *
 * @example Redaction middleware
 * ```ts
 * const redactionMiddleware: ChatMiddleware = {
 *   name: 'redaction',
 *   onChunk(ctx, chunk) {
 *     if (chunk.type === 'TEXT_MESSAGE_CONTENT') {
 *       return { ...chunk, delta: redact(chunk.delta) }
 *     }
 *   },
 * }
 * ```
 */
export interface ChatMiddleware {
  /** Optional name for debugging and identification */
  name?: string

  /**
   * Called to observe or transform the chat configuration.
   * Called at init and at the beginning of each agent iteration.
   *
   * Return a partial config to merge with the current config, or void to pass through.
   * Only the fields you return are overwritten — everything else is preserved.
   */
  onConfig?: (
    ctx: ChatMiddlewareContext,
    config: ChatMiddlewareConfig,
  ) =>
    | void
    | null
    | Partial<ChatMiddlewareConfig>
    // The async branch mirrors the sync branch (including `null`) so an
    // async middleware can signal pass-through the same way a sync one can.
    | Promise<void | null | Partial<ChatMiddlewareConfig>>

  /**
   * Called when the chat run starts (after initial onConfig).
   */
  onStart?: (ctx: ChatMiddlewareContext) => void | Promise<void>

  /**
   * Called at the start of each agent loop iteration, after a new assistant message ID
   * is created. Use this to observe iteration boundaries.
   */
  onIteration?: (
    ctx: ChatMiddlewareContext,
    info: IterationInfo,
  ) => void | Promise<void>

  /**
   * Called for every chunk yielded by chat().
   * Can observe, transform, expand, or drop chunks.
   *
   * @returns void (pass through), chunk (replace), chunk[] (expand), null (drop)
   */
  onChunk?: (
    ctx: ChatMiddlewareContext,
    chunk: StreamChunk,
  ) =>
    | void
    | StreamChunk
    | Array<StreamChunk>
    | null
    | Promise<void | StreamChunk | Array<StreamChunk> | null>

  /**
   * Called before a tool is executed.
   * Can observe, transform args, skip execution, or abort the run.
   */
  onBeforeToolCall?: (
    ctx: ChatMiddlewareContext,
    hookCtx: ToolCallHookContext,
  ) => BeforeToolCallDecision | Promise<BeforeToolCallDecision>

  /**
   * Called after a tool execution completes (success or failure).
   */
  onAfterToolCall?: (
    ctx: ChatMiddlewareContext,
    info: AfterToolCallInfo,
  ) => void | Promise<void>

  /**
   * Called after all tool calls in an iteration have been processed.
   * Provides aggregate data about tool execution results, approvals, and client tools.
   */
  onToolPhaseComplete?: (
    ctx: ChatMiddlewareContext,
    info: ToolPhaseCompleteInfo,
  ) => void | Promise<void>

  /**
   * Called when usage data is available from a RUN_FINISHED chunk.
   * Called once per model iteration that reports usage.
   */
  onUsage?: (
    ctx: ChatMiddlewareContext,
    usage: UsageInfo,
  ) => void | Promise<void>

  /**
   * Called when the chat run completes normally.
   * Exactly one of onFinish/onAbort/onError will be called per run.
   */
  onFinish?: (
    ctx: ChatMiddlewareContext,
    info: FinishInfo,
  ) => void | Promise<void>

  /**
   * Called when the chat run is aborted.
   * Exactly one of onFinish/onAbort/onError will be called per run.
   */
  onAbort?: (
    ctx: ChatMiddlewareContext,
    info: AbortInfo,
  ) => void | Promise<void>

  /**
   * Called when the chat run encounters an unhandled error.
   * Exactly one of onFinish/onAbort/onError will be called per run.
   */
  onError?: (
    ctx: ChatMiddlewareContext,
    info: ErrorInfo,
  ) => void | Promise<void>
}