@tanstack/ai
Version:
Core TanStack AI library - Open source AI SDK
393 lines (373 loc) • 11.5 kB
text/typescript
import { aiEventClient } from '@tanstack/ai-event-client'
import type { StreamChunk } from '../../../types'
import type {
AbortInfo,
AfterToolCallInfo,
BeforeToolCallDecision,
ChatMiddleware,
ChatMiddlewareConfig,
ChatMiddlewareContext,
ErrorInfo,
FinishInfo,
IterationInfo,
ToolCallHookContext,
ToolPhaseCompleteInfo,
UsageInfo,
} from './types'
/**
 * Tell whether instrumentation events must be suppressed for a middleware.
 *
 * @param mw - The middleware about to be invoked.
 * @returns `true` only when the middleware's `name` is exactly `'devtools'`.
 */
function shouldSkipInstrumentation(mw: ChatMiddleware): boolean {
  const { name } = mw
  return name === 'devtools'
}
/**
 * Assemble the fields shared by every middleware instrumentation event.
 *
 * @param ctx - The chat middleware context of the current request.
 * @returns The request and stream identifiers copied from `ctx`, the
 *   context's `conversationId` exposed as `clientId`, and the current
 *   `Date.now()` value as `timestamp`.
 */
function instrumentCtx(ctx: ChatMiddlewareContext) {
  const { requestId, streamId, conversationId } = ctx
  return {
    requestId,
    streamId,
    clientId: conversationId,
    timestamp: Date.now(),
  }
}
/** Hooks that only observe: their return value is ignored by the runner. */
type MiddlewareObserverHookName =
  | 'onStart'
  | 'onAfterToolCall'
  | 'onUsage'
  | 'onFinish'
  | 'onAbort'
  | 'onError'
  | 'onIteration'
  | 'onToolPhaseComplete'

/** Every hook name reported through the `middleware:hook:executed` event. */
type MiddlewareTimedHookName =
  | MiddlewareObserverHookName
  | 'onConfig'
  | 'onBeforeToolCall'

/**
 * Internal middleware runner that manages composed execution of middleware hooks.
 * Created once per chat() invocation.
 *
 * Every `run*` method visits the middleware in array order and awaits each
 * hook before moving on to the next one. Instrumentation events are emitted
 * after a hook has settled successfully; a hook that throws or rejects
 * propagates to the caller and no event is emitted for it. Middleware for
 * which `shouldSkipInstrumentation` returns `true` never produce events.
 */
export class MiddlewareRunner {
  private readonly middlewares: ReadonlyArray<ChatMiddleware>

  /**
   * @param middlewares - Middleware to run, in execution order.
   */
  constructor(middlewares: ReadonlyArray<ChatMiddleware>) {
    this.middlewares = middlewares
  }

  /** `true` when at least one middleware was registered. */
  get hasMiddleware(): boolean {
    return this.middlewares.length > 0
  }

  /**
   * Pipe config through all middleware onConfig hooks in order.
   * Each middleware receives the merged config from previous middleware.
   * Partial returns are shallow-merged with the current config.
   *
   * @param ctx - Context of the current chat request.
   * @param config - Initial config handed to the first middleware.
   * @returns The config after every non-nullish hook result was merged in.
   */
  async runOnConfig(
    ctx: ChatMiddlewareContext,
    config: ChatMiddlewareConfig,
  ): Promise<ChatMiddlewareConfig> {
    let current = config
    for (const mw of this.middlewares) {
      if (!mw.onConfig) continue
      const skip = shouldSkipInstrumentation(mw)
      const start = Date.now()
      const result = await mw.onConfig(ctx, current)
      const hasTransform = result !== undefined && result !== null
      if (hasTransform) {
        current = { ...current, ...result }
      }
      if (skip) continue
      // Both events share one base context so they carry the same timestamp.
      const base = instrumentCtx(ctx)
      this.emitHookExecuted(ctx, mw, 'onConfig', start, hasTransform, base)
      if (hasTransform) {
        aiEventClient.emit('middleware:config:transformed', {
          ...base,
          middlewareName: mw.name || 'unnamed',
          iteration: ctx.iteration,
          changes: result as Record<string, unknown>,
        })
      }
    }
    return current
  }

  /**
   * Call onStart on all middleware in order.
   *
   * @param ctx - Context of the current chat request.
   */
  async runOnStart(ctx: ChatMiddlewareContext): Promise<void> {
    await this.runObserverHook(ctx, 'onStart', (mw) => mw.onStart?.(ctx))
  }

  /**
   * Pipe a single chunk through all middleware onChunk hooks in order.
   * Returns the resulting chunks (0..N) to yield to the consumer.
   *
   * - void: pass through unchanged
   * - chunk: replace with this chunk
   * - chunk[]: expand to multiple chunks
   * - null: drop the chunk entirely
   *
   * Each middleware sees every chunk produced by the middleware before it.
   * A `middleware:chunk:transformed` event is emitted for drops, replacements
   * and expansions, but not for pass-throughs.
   *
   * @param ctx - Context of the current chat request.
   * @param chunk - The chunk received from the stream.
   * @returns The chunks left after all middleware have run.
   */
  async runOnChunk(
    ctx: ChatMiddlewareContext,
    chunk: StreamChunk,
  ): Promise<Array<StreamChunk>> {
    let chunks: Array<StreamChunk> = [chunk]
    for (const mw of this.middlewares) {
      if (!mw.onChunk) continue
      const skip = shouldSkipInstrumentation(mw)
      const nextChunks: Array<StreamChunk> = []
      for (const c of chunks) {
        const result = await mw.onChunk(ctx, c)
        if (result === undefined) {
          // Pass through — no instrumentation for pass-throughs
          nextChunks.push(c)
          continue
        }
        // Normalize drop / expand / replace into one list of output chunks.
        let produced: Array<StreamChunk>
        if (result === null) {
          produced = []
        } else if (Array.isArray(result)) {
          produced = result
        } else {
          produced = [result]
        }
        nextChunks.push(...produced)
        if (!skip) {
          aiEventClient.emit('middleware:chunk:transformed', {
            ...instrumentCtx(ctx),
            middlewareName: mw.name || 'unnamed',
            originalChunkType: c.type,
            resultCount: produced.length,
            // Only an explicit `null` counts as a drop; an empty array is an
            // expansion to zero chunks.
            wasDropped: result === null,
          })
        }
      }
      chunks = nextChunks
    }
    return chunks
  }

  /**
   * Run onBeforeToolCall through middleware in order.
   * Returns the first non-void decision, or undefined to continue normally.
   * Middleware after the one that returned a decision are not invoked.
   *
   * @param ctx - Context of the current chat request.
   * @param hookCtx - Description of the tool call about to run.
   * @returns The first non-nullish decision, otherwise `undefined`.
   */
  async runOnBeforeToolCall(
    ctx: ChatMiddlewareContext,
    hookCtx: ToolCallHookContext,
  ): Promise<BeforeToolCallDecision> {
    for (const mw of this.middlewares) {
      if (!mw.onBeforeToolCall) continue
      const skip = shouldSkipInstrumentation(mw)
      const start = Date.now()
      const decision = await mw.onBeforeToolCall(ctx, hookCtx)
      const hasTransform = decision !== undefined && decision !== null
      if (!skip) {
        this.emitHookExecuted(ctx, mw, 'onBeforeToolCall', start, hasTransform)
      }
      if (hasTransform) {
        return decision
      }
    }
    return undefined
  }

  /**
   * Run onAfterToolCall on all middleware in order.
   *
   * @param ctx - Context of the current chat request.
   * @param info - Outcome of the tool call, forwarded to each hook.
   */
  async runOnAfterToolCall(
    ctx: ChatMiddlewareContext,
    info: AfterToolCallInfo,
  ): Promise<void> {
    await this.runObserverHook(ctx, 'onAfterToolCall', (mw) =>
      mw.onAfterToolCall?.(ctx, info),
    )
  }

  /**
   * Run onUsage on all middleware in order.
   *
   * @param ctx - Context of the current chat request.
   * @param usage - Usage information forwarded to each hook.
   */
  async runOnUsage(
    ctx: ChatMiddlewareContext,
    usage: UsageInfo,
  ): Promise<void> {
    await this.runObserverHook(ctx, 'onUsage', (mw) => mw.onUsage?.(ctx, usage))
  }

  /**
   * Run onFinish on all middleware in order.
   *
   * @param ctx - Context of the current chat request.
   * @param info - Finish information forwarded to each hook.
   */
  async runOnFinish(
    ctx: ChatMiddlewareContext,
    info: FinishInfo,
  ): Promise<void> {
    await this.runObserverHook(ctx, 'onFinish', (mw) => mw.onFinish?.(ctx, info))
  }

  /**
   * Run onAbort on all middleware in order.
   *
   * @param ctx - Context of the current chat request.
   * @param info - Abort information forwarded to each hook.
   */
  async runOnAbort(ctx: ChatMiddlewareContext, info: AbortInfo): Promise<void> {
    await this.runObserverHook(ctx, 'onAbort', (mw) => mw.onAbort?.(ctx, info))
  }

  /**
   * Run onError on all middleware in order.
   *
   * @param ctx - Context of the current chat request.
   * @param info - Error information forwarded to each hook.
   */
  async runOnError(ctx: ChatMiddlewareContext, info: ErrorInfo): Promise<void> {
    await this.runObserverHook(ctx, 'onError', (mw) => mw.onError?.(ctx, info))
  }

  /**
   * Run onIteration on all middleware in order.
   * Called at the start of each agent loop iteration.
   *
   * @param ctx - Context of the current chat request.
   * @param info - Iteration information forwarded to each hook.
   */
  async runOnIteration(
    ctx: ChatMiddlewareContext,
    info: IterationInfo,
  ): Promise<void> {
    await this.runObserverHook(ctx, 'onIteration', (mw) =>
      mw.onIteration?.(ctx, info),
    )
  }

  /**
   * Run onToolPhaseComplete on all middleware in order.
   * Called after all tool calls in an iteration have been processed.
   *
   * @param ctx - Context of the current chat request.
   * @param info - Tool phase summary forwarded to each hook.
   */
  async runOnToolPhaseComplete(
    ctx: ChatMiddlewareContext,
    info: ToolPhaseCompleteInfo,
  ): Promise<void> {
    await this.runObserverHook(ctx, 'onToolPhaseComplete', (mw) =>
      mw.onToolPhaseComplete?.(ctx, info),
    )
  }

  /**
   * Shared loop behind every notification-style hook: for each middleware
   * that defines `hookName`, await `invoke` and then report the execution.
   *
   * @param ctx - Context of the current chat request.
   * @param hookName - Hook property to look for on each middleware.
   * @param invoke - Calls the hook as a method of the middleware (so `this`
   *   stays bound to it) and returns whatever the hook returned.
   */
  private async runObserverHook(
    ctx: ChatMiddlewareContext,
    hookName: MiddlewareObserverHookName,
    invoke: (mw: ChatMiddleware) => unknown,
  ): Promise<void> {
    for (const mw of this.middlewares) {
      if (!mw[hookName]) continue
      // Decide about instrumentation before the hook runs.
      const skip = shouldSkipInstrumentation(mw)
      const start = Date.now()
      await invoke(mw)
      if (!skip) {
        this.emitHookExecuted(ctx, mw, hookName, start, false)
      }
    }
  }

  /**
   * Emit one `middleware:hook:executed` instrumentation event.
   * Callers are responsible for honouring `shouldSkipInstrumentation`.
   *
   * @param ctx - Context of the current chat request.
   * @param mw - Middleware whose hook just ran; an empty or missing name is
   *   reported as `'unnamed'`.
   * @param hookName - Name of the hook that ran.
   * @param start - `Date.now()` value captured right before the hook ran.
   * @param hasTransform - Whether the hook returned a non-nullish value.
   * @param base - Base event fields; defaults to a fresh `instrumentCtx(ctx)`.
   */
  private emitHookExecuted(
    ctx: ChatMiddlewareContext,
    mw: ChatMiddleware,
    hookName: MiddlewareTimedHookName,
    start: number,
    hasTransform: boolean,
    base: ReturnType<typeof instrumentCtx> = instrumentCtx(ctx),
  ): void {
    aiEventClient.emit('middleware:hook:executed', {
      ...base,
      middlewareName: mw.name || 'unnamed',
      hookName,
      iteration: ctx.iteration,
      duration: Date.now() - start,
      hasTransform,
    })
  }
}