@tanstack/ai
Version:
Type-safe TypeScript AI SDK for streaming chat, tool calling, agents, structured outputs, and multimodal generation.
511 lines (510 loc) • 17 kB
JavaScript
import { aiEventClient } from "@tanstack/ai-event-client";
/**
 * Tells whether a middleware is exempt from logging and event emission.
 *
 * Only the middleware's `name` is inspected: the names "devtools" and
 * "strip-to-spec" are exempt, every other name (or a missing name) is not.
 *
 * @param {{ name?: string }} mw - Middleware object to check.
 * @returns {boolean} `true` when instrumentation must be skipped for `mw`.
 */
function shouldSkipInstrumentation(mw) {
  const { name } = mw;
  return ["devtools", "strip-to-spec"].includes(name);
}
/**
 * Builds the identifying fields placed at the front of every emitted event.
 *
 * `requestId` and `streamId` are copied from `ctx` as-is, `ctx.threadId` is
 * exposed under the key `clientId`, and `timestamp` is the current
 * `Date.now()` value at the moment of the call.
 *
 * @param {object} ctx - Context carrying `requestId`, `streamId` and `threadId`.
 * @returns {{ requestId: *, streamId: *, clientId: *, timestamp: number }}
 */
function instrumentCtx(ctx) {
  const { requestId, streamId, threadId } = ctx;
  const timestamp = Date.now();
  return { requestId, streamId, clientId: threadId, timestamp };
}
/**
 * Runs an ordered list of middleware objects through each lifecycle hook.
 *
 * For every middleware that `shouldSkipInstrumentation` does not exempt, the
 * runner also writes a log line through the injected logger and emits
 * instrumentation events through `aiEventClient`. Hooks always run
 * sequentially, in array order; a hook that throws/rejects aborts the run and
 * the error propagates to the caller.
 */
class MiddlewareRunner {
  middlewares;
  logger;

  /**
   * @param {Array<object>} middlewares - Middleware objects, in execution order.
   * @param {object} logger - Logger exposing `config`, `middleware` and `warn`.
   */
  constructor(middlewares, logger) {
    this.middlewares = middlewares;
    this.logger = logger;
  }

  /** `true` when at least one middleware is registered. */
  get hasMiddleware() {
    return this.middlewares.length > 0;
  }

  /**
   * Display name used in log lines, events and error messages.
   * A missing or empty `name` is reported as "unnamed" everywhere, so logs and
   * events always agree on how a middleware is labelled.
   */
  #labelOf(mw) {
    return mw.name || "unnamed";
  }

  /**
   * Emit `middleware:hook:executed` for one hook invocation.
   * Returns the base event fields so a follow-up event can share the same
   * timestamp.
   */
  #emitHookExecuted(ctx, label, hookName, start, hasTransform) {
    const base = instrumentCtx(ctx);
    aiEventClient.emit("middleware:hook:executed", {
      ...base,
      middlewareName: label,
      hookName,
      iteration: ctx.iteration,
      duration: Date.now() - start,
      hasTransform
    });
    return base;
  }

  /** Log one hook invocation and emit its `middleware:hook:executed` event. */
  #reportHook(ctx, mw, hookName, start, hasTransform) {
    const label = this.#labelOf(mw);
    this.logger.middleware(`hook=${hookName} middleware=${label}`, {
      middleware: label,
      hook: hookName
    });
    this.#emitHookExecuted(ctx, label, hookName, start, hasTransform);
  }

  /**
   * Call a notification-style hook (return value ignored) on every middleware
   * that implements it, forwarding `ctx` followed by `args`.
   */
  async #runNotifyHook(hookName, ctx, ...args) {
    for (const mw of this.middlewares) {
      if (!mw[hookName]) continue;
      const skip = shouldSkipInstrumentation(mw);
      const start = Date.now();
      await mw[hookName](ctx, ...args);
      if (!skip) {
        this.#reportHook(ctx, mw, hookName, start, false);
      }
    }
  }

  /**
   * Pipe `config` through the named config hook of every middleware.
   * A non-nullish return value is shallow-merged over the running config.
   * `toChanges` maps that return value to the `changes` field of the
   * `middleware:config:transformed` event.
   */
  async #pipeConfig(hookName, ctx, config, toChanges) {
    let current = config;
    for (const mw of this.middlewares) {
      if (!mw[hookName]) continue;
      const skip = shouldSkipInstrumentation(mw);
      const start = Date.now();
      const result = await mw[hookName](ctx, current);
      const hasTransform = result !== undefined && result !== null;
      if (hasTransform) {
        current = { ...current, ...result };
      }
      if (skip) continue;
      const label = this.#labelOf(mw);
      if (hasTransform) {
        this.logger.config(
          `middleware=${label} keys=${Object.keys(result).join(",")}`,
          { middleware: label, changes: result }
        );
      }
      const base = this.#emitHookExecuted(ctx, label, hookName, start, hasTransform);
      if (hasTransform) {
        aiEventClient.emit("middleware:config:transformed", {
          ...base,
          middlewareName: label,
          iteration: ctx.iteration,
          changes: toChanges(result)
        });
      }
    }
    return current;
  }

  /** Log and emit the outcome of one `onChunk` call that did not pass through. */
  #reportChunkTransform(ctx, mw, input, chunkType, result) {
    const label = this.#labelOf(mw);
    const wasDropped = result === null;
    let outText;
    let resultCount;
    let meta;
    if (wasDropped) {
      outText = "<dropped>";
      resultCount = 0;
      meta = { middleware: label, hook: "onChunk", dropped: true };
    } else {
      if (Array.isArray(result)) {
        outText = `[${result.map((r) => r.type).join(",")}]`;
        resultCount = result.length;
      } else {
        outText = result.type;
        resultCount = 1;
      }
      meta = { middleware: label, hook: "onChunk", in: input, out: result };
    }
    this.logger.middleware(
      `hook=onChunk middleware=${label} in=${chunkType} out=${outText}`,
      meta
    );
    aiEventClient.emit("middleware:chunk:transformed", {
      ...instrumentCtx(ctx),
      middlewareName: label,
      originalChunkType: chunkType,
      resultCount,
      wasDropped
    });
  }

  /**
   * Pipe config through all middleware onConfig hooks in order.
   * Each middleware receives the merged config from previous middleware.
   * Partial returns are shallow-merged with the current config; a `null` or
   * `undefined` return leaves the config untouched.
   *
   * @returns {Promise<object>} The config after every middleware has run.
   */
  async runOnConfig(ctx, config) {
    return this.#pipeConfig("onConfig", ctx, config, (result) => result);
  }

  /**
   * Pipe config through all middleware onStructuredOutputConfig hooks in order.
   * Each middleware receives the merged config from previous middleware.
   * Partial returns are shallow-merged with the current config.
   *
   * Called once at the structured-output boundary, before runOnConfig at the
   * same boundary (which receives a ChatMiddlewareConfig view, no outputSchema).
   *
   * @returns {Promise<object>} The config after every middleware has run.
   */
  async runOnStructuredOutputConfig(ctx, config) {
    // The emitted `changes` is a fresh plain-object copy of the hook's return
    // value rather than the returned object itself.
    return this.#pipeConfig(
      "onStructuredOutputConfig",
      ctx,
      config,
      (result) => Object.fromEntries(Object.entries(result))
    );
  }

  /**
   * Run all `setup` hooks in array order, then assert every declared `provides`
   * capability was actually provided. Wires the last-wins duplicate-provide
   * warning into the registry. Runs before init `onConfig`.
   *
   * Takes the full `ChatMiddlewareContext` — the same stable context the engine
   * threads through every other hook — because it both forwards `ctx` to each
   * `setup` hook and emits instrumentation events from it.
   *
   * @throws {Error} When a middleware lists a handle in `provides` that
   *   `ctx.capabilities.has` does not report after all `setup` hooks ran.
   */
  async runSetup(ctx) {
    ctx.capabilities.setOnDuplicate((name) => {
      this.logger.warn(
        `capability "${name}" was provided more than once; last provider wins`,
        { capability: name }
      );
    });
    await this.#runNotifyHook("setup", ctx);
    // Checked only after every setup hook has run, so the order in which
    // middleware provide their capabilities does not matter.
    for (const mw of this.middlewares) {
      for (const handle of mw.provides ?? []) {
        if (!ctx.capabilities.has(handle)) {
          throw new Error(
            `Middleware "${this.#labelOf(mw)}" declares it provides "${handle.capabilityName}" but never called provide() in setup().`
          );
        }
      }
    }
  }

  /**
   * Call onStart on all middleware in order.
   */
  async runOnStart(ctx) {
    await this.#runNotifyHook("onStart", ctx);
  }

  /**
   * Pipe a single chunk through all middleware onChunk hooks in order.
   * Returns the resulting chunks (0..N) to yield to the consumer.
   *
   * - void: pass through unchanged
   * - chunk: replace with this chunk
   * - chunk[]: expand to multiple chunks
   * - null: drop the chunk entirely
   *
   * Each middleware sees every chunk produced by the middleware before it.
   */
  async runOnChunk(ctx, chunk) {
    let chunks = [chunk];
    for (const mw of this.middlewares) {
      if (!mw.onChunk) continue;
      const skip = shouldSkipInstrumentation(mw);
      const nextChunks = [];
      for (const c of chunks) {
        // Read before the hook runs so logs/events report the incoming type
        // even if the hook mutates the chunk in place.
        const chunkType = c.type;
        if (!skip) {
          const label = this.#labelOf(mw);
          this.logger.middleware(
            `hook=onChunk middleware=${label} in=${chunkType}`,
            { middleware: label, hook: "onChunk", in: c }
          );
        }
        const result = await mw.onChunk(ctx, c);
        if (result === undefined) {
          // Pass-through is not a transformation: nothing further is reported.
          nextChunks.push(c);
          continue;
        }
        if (Array.isArray(result)) {
          nextChunks.push(...result);
        } else if (result !== null) {
          nextChunks.push(result);
        }
        if (!skip) {
          this.#reportChunkTransform(ctx, mw, c, chunkType, result);
        }
      }
      chunks = nextChunks;
    }
    return chunks;
  }

  /**
   * Run onBeforeToolCall through middleware in order.
   * Returns the first decision that is neither `undefined` nor `null`; later
   * middleware are not consulted once a decision is made. Resolves to
   * `undefined` when no middleware decides, to continue normally.
   */
  async runOnBeforeToolCall(ctx, hookCtx) {
    for (const mw of this.middlewares) {
      if (!mw.onBeforeToolCall) continue;
      const skip = shouldSkipInstrumentation(mw);
      const start = Date.now();
      const decision = await mw.onBeforeToolCall(ctx, hookCtx);
      const hasTransform = decision !== undefined && decision !== null;
      if (!skip) {
        this.#reportHook(ctx, mw, "onBeforeToolCall", start, hasTransform);
      }
      if (hasTransform) {
        return decision;
      }
    }
    return undefined;
  }

  /**
   * Run onAfterToolCall on all middleware in order.
   */
  async runOnAfterToolCall(ctx, info) {
    await this.#runNotifyHook("onAfterToolCall", ctx, info);
  }

  /**
   * Run onUsage on all middleware in order.
   */
  async runOnUsage(ctx, usage) {
    await this.#runNotifyHook("onUsage", ctx, usage);
  }

  /**
   * Run onFinish on all middleware in order.
   */
  async runOnFinish(ctx, info) {
    await this.#runNotifyHook("onFinish", ctx, info);
  }

  /**
   * Run onAbort on all middleware in order.
   */
  async runOnAbort(ctx, info) {
    await this.#runNotifyHook("onAbort", ctx, info);
  }

  /**
   * Run onError on all middleware in order.
   */
  async runOnError(ctx, info) {
    await this.#runNotifyHook("onError", ctx, info);
  }

  /**
   * Run onIteration on all middleware in order.
   * Called at the start of each agent loop iteration.
   */
  async runOnIteration(ctx, info) {
    await this.#runNotifyHook("onIteration", ctx, info);
  }

  /**
   * Run onToolPhaseComplete on all middleware in order.
   * Called after all tool calls in an iteration have been processed.
   */
  async runOnToolPhaseComplete(ctx, info) {
    await this.#runNotifyHook("onToolPhaseComplete", ctx, info);
  }
}
export {
MiddlewareRunner
};
//# sourceMappingURL=compose.js.map