UNPKG

@tanstack/ai

Version:

Core TanStack AI library - Open source AI SDK

393 lines (373 loc) 11.5 kB
import { aiEventClient } from '@tanstack/ai-event-client'
import type { StreamChunk } from '../../../types'
import type {
  AbortInfo,
  AfterToolCallInfo,
  BeforeToolCallDecision,
  ChatMiddleware,
  ChatMiddlewareConfig,
  ChatMiddlewareContext,
  ErrorInfo,
  FinishInfo,
  IterationInfo,
  ToolCallHookContext,
  ToolPhaseCompleteInfo,
  UsageInfo,
} from './types'

/**
 * Hooks whose return value the runner ignores: they are called for their side
 * effects only, so they all share one instrumented loop (`runObserverHook`).
 */
type ObserverHookName =
  | 'onStart'
  | 'onAfterToolCall'
  | 'onUsage'
  | 'onFinish'
  | 'onAbort'
  | 'onError'
  | 'onIteration'
  | 'onToolPhaseComplete'

/** Every hook name this runner reports in `middleware:hook:executed` events. */
type InstrumentedHookName = ObserverHookName | 'onConfig' | 'onBeforeToolCall'

/** Check if a middleware should be skipped for instrumentation events. */
function shouldSkipInstrumentation(mw: ChatMiddleware): boolean {
  return mw.name === 'devtools'
}

/** Name reported for a middleware in instrumentation events ('unnamed' if it has none). */
function middlewareLabel(mw: ChatMiddleware): string {
  return mw.name || 'unnamed'
}

/** Build the base context for middleware instrumentation events. */
function instrumentCtx(ctx: ChatMiddlewareContext) {
  return {
    requestId: ctx.requestId,
    streamId: ctx.streamId,
    clientId: ctx.conversationId,
    timestamp: Date.now(),
  }
}

/**
 * Emit a `middleware:hook:executed` event for one hook invocation.
 *
 * @param ctx - Middleware context; supplies the request ids and iteration.
 * @param mw - The middleware whose hook just ran.
 * @param hookName - Name of the hook that ran.
 * @param start - `Date.now()` captured just before the hook was called.
 * @param hasTransform - Whether the hook returned a non-nullish value.
 * @param base - Event context. Pass one explicitly when several events for the
 *   same invocation must share a timestamp; otherwise a fresh one is built.
 */
function emitHookExecuted(
  ctx: ChatMiddlewareContext,
  mw: ChatMiddleware,
  hookName: InstrumentedHookName,
  start: number,
  hasTransform: boolean,
  base: ReturnType<typeof instrumentCtx> = instrumentCtx(ctx),
): void {
  aiEventClient.emit('middleware:hook:executed', {
    ...base,
    middlewareName: middlewareLabel(mw),
    hookName,
    iteration: ctx.iteration,
    duration: Date.now() - start,
    hasTransform,
  })
}

/**
 * Emit a `middleware:chunk:transformed` event for one non-pass-through
 * `onChunk` result.
 *
 * @param original - The chunk that was handed to the middleware.
 * @param resultCount - Number of chunks the middleware produced in its place.
 * @param wasDropped - True only when the middleware returned `null`.
 */
function emitChunkTransformed(
  ctx: ChatMiddlewareContext,
  mw: ChatMiddleware,
  original: StreamChunk,
  resultCount: number,
  wasDropped: boolean,
): void {
  aiEventClient.emit('middleware:chunk:transformed', {
    ...instrumentCtx(ctx),
    middlewareName: middlewareLabel(mw),
    originalChunkType: original.type,
    resultCount,
    wasDropped,
  })
}

/**
 * Internal middleware runner that manages composed execution of middleware hooks.
 * Created once per chat() invocation.
 *
 * Hooks are awaited one at a time, in the order the middleware were supplied.
 * Instrumentation events are emitted through `aiEventClient`, except for the
 * middleware named 'devtools'.
 */
export class MiddlewareRunner {
  private readonly middlewares: ReadonlyArray<ChatMiddleware>

  constructor(middlewares: ReadonlyArray<ChatMiddleware>) {
    this.middlewares = middlewares
  }

  /** True when at least one middleware was supplied. */
  get hasMiddleware(): boolean {
    return this.middlewares.length > 0
  }

  /**
   * Pipe config through all middleware onConfig hooks in order.
   * Each middleware receives the merged config from previous middleware.
   * Partial returns are shallow-merged with the current config.
   *
   * @returns The final merged config.
   */
  async runOnConfig(
    ctx: ChatMiddlewareContext,
    config: ChatMiddlewareConfig,
  ): Promise<ChatMiddlewareConfig> {
    let current = config
    for (const mw of this.middlewares) {
      if (!mw.onConfig) continue
      const skip = shouldSkipInstrumentation(mw)
      const start = Date.now()
      const result = await mw.onConfig(ctx, current)
      const hasTransform = result !== undefined && result !== null
      if (hasTransform) {
        current = { ...current, ...result }
      }
      if (skip) continue
      // One shared base so both events for this invocation carry the same timestamp.
      const base = instrumentCtx(ctx)
      emitHookExecuted(ctx, mw, 'onConfig', start, hasTransform, base)
      if (hasTransform) {
        aiEventClient.emit('middleware:config:transformed', {
          ...base,
          middlewareName: middlewareLabel(mw),
          iteration: ctx.iteration,
          changes: result as Record<string, unknown>,
        })
      }
    }
    return current
  }

  /**
   * Call onStart on all middleware in order.
   */
  async runOnStart(ctx: ChatMiddlewareContext): Promise<void> {
    await this.runObserverHook('onStart', ctx, (mw) => mw.onStart?.(ctx))
  }

  /**
   * Pipe a single chunk through all middleware onChunk hooks in order.
   * Returns the resulting chunks (0..N) to yield to the consumer.
   *
   * Each middleware sees every chunk produced by the previous one. Its return value:
   * - void (undefined): pass through unchanged (not instrumented)
   * - chunk: replace with this chunk
   * - chunk[]: expand to multiple chunks (an empty array yields nothing, but
   *   is reported with `wasDropped: false`)
   * - null: drop the chunk entirely
   */
  async runOnChunk(
    ctx: ChatMiddlewareContext,
    chunk: StreamChunk,
  ): Promise<Array<StreamChunk>> {
    let chunks: Array<StreamChunk> = [chunk]
    for (const mw of this.middlewares) {
      if (!mw.onChunk) continue
      const skip = shouldSkipInstrumentation(mw)
      const nextChunks: Array<StreamChunk> = []
      for (const c of chunks) {
        const result = await mw.onChunk(ctx, c)
        if (result === undefined) {
          // Pass through — no instrumentation for pass-throughs
          nextChunks.push(c)
          continue
        }
        let resultCount: number
        if (result === null) {
          // Drop this chunk
          resultCount = 0
        } else if (Array.isArray(result)) {
          // Expand
          nextChunks.push(...result)
          resultCount = result.length
        } else {
          // Replace
          nextChunks.push(result)
          resultCount = 1
        }
        if (!skip) {
          emitChunkTransformed(ctx, mw, c, resultCount, result === null)
        }
      }
      chunks = nextChunks
    }
    return chunks
  }

  /**
   * Run onBeforeToolCall through middleware in order.
   * Returns the first non-void decision, or undefined to continue normally.
   * Middleware after the one that decides are not called.
   */
  async runOnBeforeToolCall(
    ctx: ChatMiddlewareContext,
    hookCtx: ToolCallHookContext,
  ): Promise<BeforeToolCallDecision> {
    for (const mw of this.middlewares) {
      if (!mw.onBeforeToolCall) continue
      const skip = shouldSkipInstrumentation(mw)
      const start = Date.now()
      const decision = await mw.onBeforeToolCall(ctx, hookCtx)
      const hasTransform = decision !== undefined && decision !== null
      if (!skip) {
        emitHookExecuted(ctx, mw, 'onBeforeToolCall', start, hasTransform)
      }
      if (hasTransform) {
        return decision
      }
    }
    return undefined
  }

  /**
   * Run onAfterToolCall on all middleware in order.
   */
  async runOnAfterToolCall(
    ctx: ChatMiddlewareContext,
    info: AfterToolCallInfo,
  ): Promise<void> {
    await this.runObserverHook('onAfterToolCall', ctx, (mw) =>
      mw.onAfterToolCall?.(ctx, info),
    )
  }

  /**
   * Run onUsage on all middleware in order.
   */
  async runOnUsage(
    ctx: ChatMiddlewareContext,
    usage: UsageInfo,
  ): Promise<void> {
    await this.runObserverHook('onUsage', ctx, (mw) => mw.onUsage?.(ctx, usage))
  }

  /**
   * Run onFinish on all middleware in order.
   */
  async runOnFinish(
    ctx: ChatMiddlewareContext,
    info: FinishInfo,
  ): Promise<void> {
    await this.runObserverHook('onFinish', ctx, (mw) => mw.onFinish?.(ctx, info))
  }

  /**
   * Run onAbort on all middleware in order.
   */
  async runOnAbort(ctx: ChatMiddlewareContext, info: AbortInfo): Promise<void> {
    await this.runObserverHook('onAbort', ctx, (mw) => mw.onAbort?.(ctx, info))
  }

  /**
   * Run onError on all middleware in order.
   */
  async runOnError(ctx: ChatMiddlewareContext, info: ErrorInfo): Promise<void> {
    await this.runObserverHook('onError', ctx, (mw) => mw.onError?.(ctx, info))
  }

  /**
   * Run onIteration on all middleware in order.
   * Called at the start of each agent loop iteration.
   */
  async runOnIteration(
    ctx: ChatMiddlewareContext,
    info: IterationInfo,
  ): Promise<void> {
    await this.runObserverHook('onIteration', ctx, (mw) =>
      mw.onIteration?.(ctx, info),
    )
  }

  /**
   * Run onToolPhaseComplete on all middleware in order.
   * Called after all tool calls in an iteration have been processed.
   */
  async runOnToolPhaseComplete(
    ctx: ChatMiddlewareContext,
    info: ToolPhaseCompleteInfo,
  ): Promise<void> {
    await this.runObserverHook('onToolPhaseComplete', ctx, (mw) =>
      mw.onToolPhaseComplete?.(ctx, info),
    )
  }

  /**
   * Shared loop for hooks whose return value is ignored.
   *
   * For each middleware that defines `hookName`, in order, this awaits
   * `invoke(mw)`. Unless the middleware is 'devtools', it then reports the
   * call with `hasTransform: false`. A hook that throws stops the loop and
   * propagates, and that call is not reported.
   *
   * @param hookName - Hook to look up on each middleware and report in events.
   * @param ctx - Middleware context used for instrumentation.
   * @param invoke - Calls the hook. It must call it as a method
   *   (`mw.hook?.(...)`) so `this` stays bound to the middleware.
   */
  private async runObserverHook(
    hookName: ObserverHookName,
    ctx: ChatMiddlewareContext,
    invoke: (mw: ChatMiddleware) => unknown,
  ): Promise<void> {
    for (const mw of this.middlewares) {
      if (!mw[hookName]) continue
      const skip = shouldSkipInstrumentation(mw)
      const start = Date.now()
      await invoke(mw)
      if (!skip) {
        emitHookExecuted(ctx, mw, hookName, start, false)
      }
    }
  }
}