@genkit-ai/ai
Version:
Genkit AI framework generative AI APIs.
1,126 lines (1,011 loc) • 31.3 kB
text/typescript
/**
* Copyright 2026 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { z } from '@genkit-ai/core';
import { initNodeFeatures } from '@genkit-ai/core/node';
import { Registry } from '@genkit-ai/core/registry';
import * as assert from 'assert';
import { beforeEach, describe, it } from 'node:test';
import {
GenerateResponseChunk,
generate,
generateStream,
} from '../../src/generate.js';
import {
GenerateMiddlewareDef,
generateMiddleware,
} from '../../src/generate/middleware.js';
import { resolveRestartedTools } from '../../src/generate/resolve-tool-requests.js';
import { defineModel } from '../../src/model.js';
import { ToolInterruptError, defineTool, tool } from '../../src/tool.js';
// Runs once at module load, before any test in this file is registered.
// NOTE(review): presumably enables Node-specific runtime support required by
// the framework — confirm against '@genkit-ai/core/node'.
initNodeFeatures();
describe('generateMiddleware', () => {
// Registry used by every test in this suite; (re)assigned in beforeEach.
let registry: Registry;
// Replace the registry with a brand-new instance before each test runs.
beforeEach(() => {
registry = new Registry();
});
it('runs generate and model middleware in the correct order', async () => {
  // Each hook and the model itself append a label here, so the final array
  // is the observed call sequence.
  const calls: string[] = [];
  const record = (label: string) => {
    calls.push(label);
  };

  const model = defineModel(registry, { name: 'mockModel' }, async () => {
    record('modelExecution');
    return {
      message: {
        role: 'model',
        content: [{ text: 'response' }],
      },
    };
  });

  // One middleware that wraps both the generate hook and the model hook.
  const orderMw = generateMiddleware({ name: 'testMiddleware' }, () => ({
    async generate(envelope, ctx, next) {
      record('generateBefore');
      const out = await next(envelope, ctx);
      record('generateAfter');
      return out;
    },
    async model(request, ctx, next) {
      record('modelBefore');
      const out = await next(request, ctx);
      record('modelAfter');
      return out;
    },
  }));

  await generate(registry, {
    model,
    prompt: 'hi',
    use: [orderMw()],
  });

  // The generate hook encloses the model hook, which encloses the model call.
  const expected = [
    'generateBefore',
    'modelBefore',
    'modelExecution',
    'modelAfter',
    'generateAfter',
  ];
  assert.deepStrictEqual(calls, expected);
});
// Verifies the relative ordering of the generate, model and tool hooks across
// a two-turn conversation (tool request on turn 1, plain text on turn 2).
it('runs tool middleware correctly', async () => {
// Labels pushed by the tool, the model and every hook, in call order.
const executionOrder: string[] = [];
const mockTool = defineTool(
registry,
{
name: 'mockTool',
description: 'A mock tool',
inputSchema: z.object({}),
outputSchema: z.string(),
},
async () => {
executionOrder.push('toolExecution');
return 'tool output';
}
);
// Number of times the model has been invoked so far.
let turns = 0;
const mockModel = defineModel(
registry,
{ name: 'mockModelWithTool' },
async (req) => {
executionOrder.push('modelExecution');
turns++;
// First invocation: ask for the tool. Any later invocation: answer with text.
if (turns === 1) {
return {
message: {
role: 'model',
content: [
{
toolRequest: {
name: mockTool.__action.name,
ref: '123',
input: {},
},
},
],
},
};
} else {
return {
message: {
role: 'model',
content: [{ text: 'final response' }],
},
};
}
}
);
const testMiddleware = generateMiddleware(
{ name: 'testMiddleware' },
() => {
// Counter shared by all three hooks of this middleware instance.
let turnCount = 0;
return {
generate: async (req, ctx, next) => {
// Capture the turn number locally: turnCount may have advanced by the
// time next() resolves, and the "After" label must match the "Before".
const t = ++turnCount;
executionOrder.push('generateBefore-' + t);
const res = await next(req, ctx);
executionOrder.push('generateAfter-' + t);
return res;
},
model: async (req, ctx, next) => {
// Reads the shared counter directly (also after the await).
executionOrder.push(`modelBefore-${turnCount}`);
const res = await next(req, ctx);
executionOrder.push(`modelAfter-${turnCount}`);
return res;
},
tool: async (req, ctx, next) => {
// Reads the shared counter directly (also after the await).
executionOrder.push(`toolBefore-${turnCount}`);
const res = await next(req, ctx);
executionOrder.push(`toolAfter-${turnCount}`);
return res;
},
};
}
);
await generate(registry, {
model: mockModel,
tools: [mockTool],
prompt: 'hi',
use: [testMiddleware()],
});
// The second generate hook invocation is nested inside the first one:
// 'generateAfter-2' is recorded before 'generateAfter-1'.
assert.deepStrictEqual(executionOrder, [
'generateBefore-1',
'modelBefore-1', // Turn 1
'modelExecution',
'modelAfter-1',
'toolBefore-1', // Tool execution
'toolExecution',
'toolAfter-1',
'generateBefore-2',
'modelBefore-2', // Turn 2
'modelExecution',
'modelAfter-2',
'generateAfter-2',
'generateAfter-1',
]);
});
it('supports configuration and old-style function middleware', async () => {
  const model = defineModel(registry, { name: 'mockModel' }, async () => ({
    message: {
      role: 'model',
      content: [{ text: 'response' }],
    },
  }));

  // New-style middleware: its model hook copies the configured `val` into
  // `seenConfig` (or '' when no value is available).
  let seenConfig = '';
  const configuredMw = generateMiddleware(
    { name: 'configMw', configSchema: z.object({ val: z.string() }) },
    (options) => ({
      async model(req, ctx, next) {
        const val = options.config?.val;
        seenConfig = val ? val : '';
        return next(req, ctx);
      },
    })
  );

  // Old-style middleware: a bare `(req, next)` function.
  let legacyRan = false;
  const legacyMw = async (req: any, next: any) => {
    legacyRan = true;
    return next(req);
  };

  await generate(registry, {
    model,
    prompt: 'test',
    use: [configuredMw({ val: 'test_config' }), legacyMw],
  });

  // Both flavours must have been invoked during the single generate call.
  assert.strictEqual(seenConfig, 'test_config');
  assert.strictEqual(legacyRan, true);
});
it('supports pre-registered middleware (e.g. installed via plugin)', async () => {
  let ran = false;
  let seenPluginOption = '';

  const model = defineModel(registry, { name: 'mockModel' }, async () => ({
    message: {
      role: 'model',
      content: [{ text: 'response' }],
    },
  }));

  // The model hook records that it ran and which plugin option it received.
  const preRegisteredMw = generateMiddleware<{ pluginOption: string }>(
    { name: 'preRegisteredMw', configSchema: z.object({ val: z.string() }) },
    (middlewareOpts) => ({
      async model(req, ctx, next) {
        ran = true;
        const fromPlugin = middlewareOpts.pluginConfig?.pluginOption;
        seenPluginOption = fromPlugin ? fromPlugin : '';
        return next(req, ctx);
      },
    })
  );

  // Act as a plugin registering the middleware
  const plugin = preRegisteredMw.plugin({ pluginOption: 'plugin_config' });
  assert.ok(plugin.middleware);
  assert.deepStrictEqual((plugin.middleware()[0] as any).pluginOptions, {
    pluginOption: 'plugin_config',
  });

  const provided = plugin.middleware();
  assert.strictEqual(provided.length, 1);

  // Verify name and properties are correctly copied/preserved
  const first = provided[0];
  assert.strictEqual(first.name, 'preRegisteredMw');
  assert.strictEqual(first.configSchema, preRegisteredMw.configSchema);
  assert.strictEqual(first.toJson, preRegisteredMw.toJson);

  // Store every provided middleware in the registry under its own name.
  for (const entry of provided as any[]) {
    registry.registerValue('middleware', entry.name, entry);
  }

  // Reference the middleware by name only; no factory call is passed here.
  await generate(registry, {
    model,
    prompt: 'test',
    use: [{ name: 'preRegisteredMw' }],
  });

  assert.strictEqual(ran, true);
  assert.strictEqual(seenPluginOption, 'plugin_config');
});
it('should resolve tools injected by middleware during restarts', async () => {
  // Tool that is only reachable through the middleware definition below; the
  // request's own `tools` list is empty.
  const injected = defineTool(
    registry,
    {
      name: 'middlewareTool',
      description: 'injected by middleware',
      inputSchema: z.object({}),
      outputSchema: z.string(),
    },
    async () => 'success'
  );
  const mwDef: GenerateMiddlewareDef = { tools: [injected] };

  // A model message holding one tool request flagged as resumed.
  const resumedPart = {
    toolRequest: { name: 'middlewareTool', input: {} },
    metadata: { resumed: true },
  };
  const rawRequest = {
    tools: [],
    messages: [{ role: 'model', content: [resumedPart] }],
  } as any;

  const resolved = await resolveRestartedTools(registry, rawRequest, [mwDef]);

  // Exactly one part comes back and it carries the tool's output.
  assert.strictEqual(resolved.length, 1);
  assert.deepStrictEqual(resolved[0].metadata?.pendingOutput, 'success');
});
it('throws an error if a middleware factory is passed without being called', async () => {
  const model = defineModel(registry, { name: 'mockModel' }, async () => ({
    message: { role: 'model', content: [{ text: 'done' }] },
    finishReason: 'stop',
  }));

  // The factory itself (note: never invoked) is what gets handed to `use`.
  const uncalledFactory = generateMiddleware({ name: 'dummy' }, () => ({}));

  // Checks the shape of the rejection: error name and message hint.
  const isMissingCallError = (err: any) => {
    assert.strictEqual(err.name, 'GenkitError');
    assert.match(err.message, /must be called with \(\)/);
    return true;
  };

  await assert.rejects(async () => {
    await generate(registry, {
      model,
      prompt: 'test',
      use: [uncalledFactory as any],
    });
  }, isMissingCallError);
});
// Verifies that both the model hook and the generate hook can wrap
// `ctx.onChunk`, that their chunk edits stack (model hook first, then
// generate hook), and that the generate hook can rewrite the final response.
it('can intercept and modify the stream from model and generate interceptors', async () => {
  // Chunk text as seen by each hook, in arrival order.
  const chunkIntercepts: string[] = [];
  const mockStreamingModel = defineModel(
    registry,
    { name: 'mockStreamingModel' },
    async (req, streamingCallback) => {
      // Emit two chunks only when a streaming callback was supplied.
      if (streamingCallback) {
        streamingCallback({ content: [{ text: 'chunk1' }] });
        streamingCallback({ content: [{ text: 'chunk2' }] });
      }
      return {
        message: {
          role: 'model',
          content: [{ text: 'chunk1chunk2' }],
        },
        finishReason: 'stop',
      };
    }
  );
  const streamModifyingMw = generateMiddleware(
    { name: 'streamModifier' },
    () => ({
      model: async (req, ctx, next) => {
        const originalOnChunk = ctx.onChunk;
        let interceptedCtx = ctx;
        // Only wrap when streaming is active; otherwise pass ctx through as-is.
        if (originalOnChunk) {
          interceptedCtx = {
            ...ctx,
            onChunk: (chunk) => {
              chunkIntercepts.push(`model_mw: ${chunk.content[0].text}`);
              // Upper-case the chunk in place before forwarding it.
              chunk.content[0].text = chunk.content[0].text?.toUpperCase();
              originalOnChunk(chunk);
            },
          };
        }
        return next(req, interceptedCtx);
      },
      generate: async (req, ctx, next) => {
        const originalOnChunk = ctx.onChunk;
        let interceptedCtx = ctx;
        // Only wrap when streaming is active; otherwise pass ctx through as-is.
        if (originalOnChunk) {
          interceptedCtx = {
            ...ctx,
            onChunk: (chunk) => {
              chunkIntercepts.push(`gen_mw: ${chunk.content[0].text}`);
              // Surround the chunk text with brackets before forwarding it.
              chunk.content[0].text = `[${chunk.content[0].text}]`;
              originalOnChunk(chunk);
            },
          };
        }
        const res = await next(req, interceptedCtx);
        // Replace the response content with a prefixed copy of its first part.
        if (res.message) {
          return {
            ...res,
            message: {
              ...res.message,
              content: [
                { text: `modified_result: ${res.message.content[0].text}` },
              ],
            },
          };
        }
        return res;
      },
    })
  );
  // Never reassigned, only appended to, hence `const`.
  const finalChunks: string[] = [];
  const { response, stream } = generateStream(registry, {
    model: mockStreamingModel,
    prompt: 'test streaming mw',
    use: [streamModifyingMw()],
  });
  for await (const chunk of stream) {
    finalChunks.push(chunk.text);
  }
  const res = await response;
  // The model hook sees the raw text; the generate hook sees the text already
  // upper-cased by the model hook.
  assert.deepStrictEqual(chunkIntercepts, [
    'model_mw: chunk1',
    'gen_mw: CHUNK1',
    'model_mw: chunk2',
    'gen_mw: CHUNK2',
  ]);
  // The stream consumer sees both transformations applied.
  assert.deepStrictEqual(finalChunks, ['[CHUNK1]', '[CHUNK2]']);
  assert.strictEqual(res.text, 'modified_result: chunk1chunk2');
});
it('executes multiple middleware in the correct order', async () => {
  const trace: string[] = [];

  // Builds a middleware named `label` whose generate and model hooks log
  // "<label>:<layer>:start" before delegating and "<label>:<layer>:end" after.
  const tracingMw = (label: string) =>
    generateMiddleware({ name: label }, () => ({
      generate: async (opts, ctx, next) => {
        trace.push(`${label}:gen:start`);
        const out = await next(opts, ctx);
        trace.push(`${label}:gen:end`);
        return out;
      },
      model: async (req, ctx, next) => {
        trace.push(`${label}:model:start`);
        const out = await next(req, ctx);
        trace.push(`${label}:model:end`);
        return out;
      },
    }));

  const mw1 = tracingMw('mw1');
  const mw2 = tracingMw('mw2');

  const model = defineModel(registry, { name: 'mockModel' }, async () => ({
    message: { role: 'model', content: [{ text: 'done' }] },
    finishReason: 'stop',
  }));

  await generate(registry, {
    model,
    prompt: 'test multiple',
    use: [mw1(), mw2()],
  });

  // The entire 'generate' layer runs before we ever descend to the 'model' level
  const expected = [
    'mw1:gen:start',
    'mw2:gen:start',
    'mw1:model:start',
    'mw2:model:start',
    'mw2:model:end',
    'mw1:model:end',
    'mw2:gen:end',
    'mw1:gen:end',
  ];
  assert.deepStrictEqual(trace, expected);
});
it('supports a combination of new middleware and old-style functional middleware', async () => {
  const trace: string[] = [];
  const log = (label: string) => {
    trace.push(label);
  };

  // New-style middleware with both a generate hook and a model hook.
  const newMw = generateMiddleware({ name: 'newMw' }, () => ({
    generate: async (opts, ctx, next) => {
      log('newMw:gen:start');
      const out = await next(opts, ctx);
      log('newMw:gen:end');
      return out;
    },
    model: async (req, ctx, next) => {
      log('newMw:model:start');
      const out = await next(req, ctx);
      log('newMw:model:end');
      return out;
    },
  }));

  // Old-style, two-parameter form; calls next with no arguments at all.
  const oldMw1 = async (req: any, next: any) => {
    log('oldMw1:model:start');
    const out = await next(); // Validating 0-argument backwards-compatibility
    log('oldMw1:model:end');
    return out;
  };

  // Old-style, three-parameter form; forwards both req and ctx.
  const oldMw2 = async (req: any, ctx: any, next: any) => {
    log('oldMw2:model:start');
    const out = await next(req, ctx);
    log('oldMw2:model:end');
    return out;
  };

  const model = defineModel(registry, { name: 'mockModel' }, async () => ({
    message: { role: 'model', content: [{ text: 'done' }] },
    finishReason: 'stop',
  }));

  await generate(registry, {
    model,
    prompt: 'test mixed',
    use: [oldMw1, newMw(), oldMw2],
  });

  assert.deepStrictEqual(trace, [
    'newMw:gen:start', // Generate level ALWAYS runs first across full array
    'oldMw1:model:start',
    'newMw:model:start',
    'oldMw2:model:start',
    'oldMw2:model:end',
    'newMw:model:end',
    'oldMw1:model:end',
    'newMw:gen:end',
  ]);
});
it('injects tools from new-style generateMiddleware and executes tool requests', async () => {
  // How many times the injected tool's implementation actually ran.
  let toolRuns = 0;
  const injectedTool = tool(
    {
      name: 'injectedTool',
      description: 'injected tool description',
      inputSchema: z.object({ arg: z.string() }),
      outputSchema: z.string(),
    },
    async (input) => {
      toolRuns += 1;
      return `Result: ${input.arg}`;
    }
  );

  // The middleware contributes nothing but the tool.
  const toolMiddleware = generateMiddleware({ name: 'toolMw' }, () => ({
    tools: [injectedTool],
  }));

  let modelCalls = 0;
  const mockToolModel = defineModel(
    registry,
    { name: 'mockToolModel' },
    async (req) => {
      modelCalls += 1;
      // Assert that the tools sent to the model include the injected tool
      assert.ok(req.tools?.some((t) => t.name === 'injectedTool'));

      // First call: request the injected tool.
      if (modelCalls === 1) {
        return {
          message: {
            role: 'model',
            content: [
              {
                toolRequest: {
                  name: 'injectedTool',
                  ref: 'call_1',
                  input: { arg: 'hello' },
                },
              },
            ],
          },
          finishReason: 'stop',
        };
      }

      // Later calls: the third message must be the tool's response.
      const toolMessage = req.messages[2];
      assert.strictEqual(toolMessage.role, 'tool');
      const toolData = toolMessage.content[0].toolResponse;
      assert.strictEqual(toolData?.name, 'injectedTool');
      assert.strictEqual(toolData?.output, 'Result: hello');
      return {
        message: { role: 'model', content: [{ text: 'final response' }] },
        finishReason: 'stop',
      };
    }
  );

  // Note: no `tools` option is passed; the tool arrives only via middleware.
  const outcome = await generate(registry, {
    model: mockToolModel,
    prompt: 'test tools',
    use: [toolMiddleware()],
  });

  assert.strictEqual(outcome.text, 'final response');
  assert.strictEqual(toolRuns, 1);
});
it('handles ToolInterruptError from middleware', async () => {
  const interruptTool = defineTool(
    registry,
    {
      name: 'interruptTool',
      description: 'interrupts',
      inputSchema: z.object({}),
      outputSchema: z.string(),
    },
    async () => 'foo'
  );

  // Tool hook that never delegates: it always throws an interrupt.
  const interruptMiddleware = generateMiddleware({ name: 'interruptMw' }, () => ({
    async tool() {
      throw new ToolInterruptError({ some: 'metadata' });
    },
  }));

  // Model that always answers with a request for the tool.
  const model = defineModel(
    registry,
    { name: 'mockModelWithTool' },
    async () => ({
      message: {
        role: 'model',
        content: [
          {
            toolRequest: {
              name: interruptTool.__action.name,
              ref: '123',
              input: {},
            },
          },
        ],
      },
    })
  );

  const outcome = await generate(registry, {
    model,
    prompt: 'hi',
    tools: ['interruptTool'],
    use: [interruptMiddleware()],
  });

  // The interrupt surfaces as a finish reason plus metadata on a message part.
  assert.strictEqual(outcome.finishReason, 'interrupted');
  const flagged = outcome.message?.content.find((part) => part.metadata?.interrupt);
  assert.ok(flagged);
  assert.deepStrictEqual(flagged.metadata?.interrupt, {
    some: 'metadata',
  });
});
// Verifies that an interrupted tool request can be restarted after the caller
// edits its metadata, and that the tool hook runs again on the restart.
it('resumes tool execution with modified metadata after interrupt', async () => {
  const mockTool = defineTool(
    registry,
    {
      name: 'interruptTool',
      description: 'interrupts',
      inputSchema: z.object({}),
      outputSchema: z.string(),
    },
    async () => {
      return 'tool output';
    }
  );
  // Number of times the tool hook below has been entered.
  let middlewareRunCount = 0;
  const interruptMiddleware = generateMiddleware(
    { name: 'interruptMw' },
    () => ({
      tool: async (req, ctx, next) => {
        middlewareRunCount++;
        // Let the call through only when the request is marked approved;
        // otherwise interrupt.
        if (req.metadata?.['approved'] === true) {
          return next(req, ctx);
        }
        throw new ToolInterruptError({ some: 'metadata' });
      },
    })
  );
  let callCount = 0;
  const mockModel = defineModel(
    registry,
    { name: 'mockModelWithTool' },
    async (req) => {
      callCount++;
      // First call asks for the tool; every later call returns plain text.
      if (callCount === 1) {
        return {
          message: {
            role: 'model',
            content: [
              {
                toolRequest: {
                  name: mockTool.__action.name,
                  ref: '123',
                  input: {},
                },
              },
            ],
          },
        };
      } else {
        return {
          message: {
            role: 'model',
            content: [{ text: 'final response' }],
          },
        };
      }
    }
  );
  const result = await generate(registry, {
    model: mockModel,
    prompt: 'hi',
    tools: ['interruptTool'],
    use: [interruptMiddleware()],
  });
  assert.strictEqual(result.finishReason, 'interrupted');
  const interruptPart = result.interrupts[0];
  assert.ok(interruptPart);
  assert.strictEqual(middlewareRunCount, 1);
  // Modify metadata. Assert it is present first so a missing value fails
  // right here instead of silently skipping the approval and surfacing later
  // as an unrelated assertion failure.
  assert.ok(
    interruptPart.metadata,
    'Interrupted tool request should carry metadata'
  );
  interruptPart.metadata = { ...interruptPart.metadata, approved: true };
  const result2 = await generate(registry, {
    model: mockModel,
    messages: result.messages,
    tools: ['interruptTool'],
    use: [interruptMiddleware()],
    resume: {
      restart: [interruptPart],
    },
  });
  assert.strictEqual(result2.text, 'final response');
  // Middleware should have run again
  assert.strictEqual(middlewareRunCount, 2);
});
// Verifies that, when a generate call resumes an interrupted tool request,
// the generate hook runs again and is shown a request whose last message is
// the tool response.
it('re-runs generate middleware after resuming tool execution', async () => {
  const mockTool = defineTool(
    registry,
    {
      name: 'interruptTool',
      description: 'interrupts',
      inputSchema: z.object({}),
      outputSchema: z.string(),
    },
    async () => {
      return 'tool output';
    }
  );
  // Both values are reset just before the resuming call so the final
  // assertions only reflect that second call.
  let generateMiddlewareCallCount = 0;
  let seenToolResponseInGenerate = false;
  const testMiddleware = generateMiddleware({ name: 'testMw' }, () => ({
    generate: async (req, ctx, next) => {
      generateMiddlewareCallCount++;
      // Flag any invocation whose most recent message has the 'tool' role.
      const lastMsg = req.request.messages[req.request.messages.length - 1];
      if (lastMsg?.role === 'tool') {
        seenToolResponseInGenerate = true;
      }
      return next(req, ctx);
    },
    tool: async (req, ctx, next) => {
      // Let the call through only when the request is marked approved;
      // otherwise interrupt.
      if (req.metadata?.['approved'] === true) {
        return next(req, ctx);
      }
      throw new ToolInterruptError({ some: 'metadata' });
    },
  }));
  let callCount = 0;
  const mockModel = defineModel(
    registry,
    { name: 'mockModelWithTool2' },
    async (req) => {
      callCount++;
      // First call asks for the tool; every later call returns plain text.
      if (callCount === 1) {
        return {
          message: {
            role: 'model',
            content: [
              {
                toolRequest: {
                  name: mockTool.__action.name,
                  ref: '123',
                  input: {},
                },
              },
            ],
          },
        };
      } else {
        return {
          message: {
            role: 'model',
            content: [{ text: 'final response' }],
          },
        };
      }
    }
  );
  const result = await generate(registry, {
    model: mockModel,
    prompt: 'hi',
    tools: ['interruptTool'],
    use: [testMiddleware()],
  });
  assert.strictEqual(result.finishReason, 'interrupted');
  const interruptPart = result.interrupts[0];
  assert.ok(interruptPart);
  // Modify metadata. Assert it is present first so a missing value fails
  // right here instead of silently skipping the approval and surfacing later
  // as an unrelated assertion failure.
  assert.ok(
    interruptPart.metadata,
    'Interrupted tool request should carry metadata'
  );
  interruptPart.metadata = { ...interruptPart.metadata, approved: true };
  generateMiddlewareCallCount = 0; // Reset
  seenToolResponseInGenerate = false;
  await generate(registry, {
    model: mockModel,
    messages: result.messages,
    tools: ['interruptTool'],
    use: [testMiddleware()],
    resume: {
      restart: [interruptPart],
    },
  });
  assert.ok(
    seenToolResponseInGenerate,
    'Generate middleware should see the tool response'
  );
  assert.strictEqual(generateMiddlewareCallCount, 2);
});
it('should handle tool middleware returning undefined', async () => {
  const swallowedTool = tool(
    {
      name: 'mockTool',
      description: 'a mock tool',
      inputSchema: z.object({}),
    },
    async () => 'tool response'
  );

  // Requests the tool while the history holds a single message, then answers
  // 'done' on any longer history.
  const model = defineModel(
    registry,
    { name: 'mockModelWithTool3' },
    async (req) => {
      const isFirstTurn = req.messages.length === 1;
      if (!isFirstTurn) {
        return { message: { role: 'model', content: [{ text: 'done' }] } };
      }
      return {
        message: {
          role: 'model',
          content: [
            {
              toolRequest: {
                name: swallowedTool.__action.name,
                ref: '123',
                input: {},
              },
            },
          ],
        },
      };
    }
  );

  // Tool hook that neither delegates nor produces a value.
  const swallowMw = generateMiddleware({ name: 'swallowToolMw' }, () => ({
    async tool(req, ctx, next) {
      return undefined; // Swallowing the tool call
    },
  }));

  const outcome = await generate(registry, {
    model,
    prompt: 'hi',
    tools: [swallowedTool],
    use: [swallowMw()],
  });

  // Verify it doesn't crash and completes.
  assert.strictEqual(outcome.text, 'done');

  // We expect 3 messages:
  // 1. User: "hi" (the prompt)
  // 2. Model: toolRequest (from Turn 1)
  // 3. Model: "done" (from Turn 2)
  // There should be NO 'tool' role message in between because the middleware swallowed it!
  const history = outcome.messages;
  assert.strictEqual(history.length, 3);
  assert.strictEqual(history[0].role, 'user');
  assert.strictEqual(history[1].role, 'model');
  assert.strictEqual(history[2].role, 'model'); // Consecutive model message!

  // Ensure no tool response parts exist
  const hasToolResponse = history.some((message) =>
    message.content.some((part) => part.toolResponse)
  );
  assert.ok(!hasToolResponse, 'Should not contain any tool response');
});
it('passes and respects envelope updates in generate middleware', async () => {
  const model = defineModel(registry, { name: 'mockModel' }, async () => ({
    message: { role: 'model', content: [{ text: 'done' }] },
  }));

  // What the first (mutating) middleware received.
  let receivedIndex = -1;
  let receivedTurn = -1;
  const bumpingMw = generateMiddleware({ name: 'test' }, () => ({
    async generate(envelope, ctx, next) {
      receivedIndex = envelope.messageIndex;
      receivedTurn = envelope.currentTurn;
      // Increment messageIndex by 5 and currentTurn by 2
      const bumped = {
        ...envelope,
        messageIndex: envelope.messageIndex + 5,
        currentTurn: envelope.currentTurn + 2,
      };
      return next(bumped, ctx);
    },
  }));

  // What the second (read-only) middleware received.
  let checkIndex = -1;
  let checkTurn = -1;
  const checkerMw = generateMiddleware({ name: 'checker' }, () => ({
    async generate(envelope, ctx, next) {
      checkIndex = envelope.messageIndex;
      checkTurn = envelope.currentTurn;
      return next(envelope, ctx);
    },
  }));

  await generate(registry, {
    model,
    prompt: 'hi',
    use: [bumpingMw(), checkerMw()],
  });

  assert.strictEqual(receivedIndex, 0, 'Initial messageIndex should be 0');
  assert.strictEqual(receivedTurn, 0, 'Initial currentTurn should be 0');
  assert.strictEqual(checkIndex, 5, 'Checker should see incremented messageIndex');
  assert.strictEqual(checkTurn, 2, 'Checker should see incremented currentTurn');
});
it('wraps raw chunks from middleware in GenerateResponseChunk', async () => {
  const model = defineModel(registry, { name: 'mockModel' }, async () => ({
    message: { role: 'model', content: [{ text: 'done' }] },
  }));

  // Generate hook that pushes a plain object through onChunk (when streaming
  // is active) before delegating.
  const rawChunkMiddleware = generateMiddleware({ name: 'rawChunk' }, () => ({
    async generate(envelope, ctx, next) {
      const emit = ctx.onChunk;
      if (emit) {
        // Send a raw object instead of GenerateResponseChunk
        emit({ content: [{ text: 'raw content' }] } as any);
      }
      return next(envelope, ctx);
    },
  }));

  const { stream, response } = generateStream(registry, {
    model,
    prompt: 'test',
    use: [rawChunkMiddleware()],
  });

  // Drain the stream, then wait for the overall response to settle.
  const collected: any[] = [];
  for await (const piece of stream) {
    collected.push(piece);
  }
  await response;

  // The raw object must surface as a chunk exposing text, index and
  // previousChunks.
  const rawChunk = collected.find((c) => c.text === 'raw content');
  assert.ok(rawChunk, 'Should find the raw content chunk');
  assert.strictEqual(rawChunk.index, 0, 'Should have index 0');
  assert.deepStrictEqual(
    rawChunk.previousChunks,
    [],
    'Should have empty previousChunks'
  );
});
it('accumulates middleware chunks into sharedPreviousChunks for subsequent model chunks', async () => {
  const streamingModel = defineModel(
    registry,
    { name: 'mockStreamingModel' },
    async (req, streamingCallback) => {
      // The model emits a chunk
      streamingCallback?.({ content: [{ text: 'model chunk' }] });
      return {
        message: { role: 'model', content: [{ text: 'done' }] },
      };
    }
  );

  const rawChunkMiddleware = generateMiddleware({ name: 'rawChunk2' }, () => ({
    async generate(envelope, ctx, next) {
      const emit = ctx.onChunk;
      if (emit) {
        // Middleware emits a raw chunk BEFORE the model runs
        emit({ content: [{ text: 'middleware chunk ' }] } as any);
      }
      return next(envelope, ctx);
    },
  }));

  const { stream, response } = generateStream(registry, {
    model: streamingModel,
    prompt: 'test',
    use: [rawChunkMiddleware()],
  });

  const received: GenerateResponseChunk[] = [];
  for await (const piece of stream) {
    received.push(piece);
  }
  await response;

  // We expect 2 chunks in the stream
  assert.strictEqual(received.length, 2);
  const [fromMiddleware, fromModel] = received;
  assert.strictEqual(fromMiddleware.text, 'middleware chunk ');
  assert.strictEqual(fromModel.text, 'model chunk');

  // CRITICAL ASSERTION: The model chunk should have the middleware chunk in its previousChunks!
  assert.strictEqual(fromModel.previousChunks!.length, 1);
  assert.strictEqual(
    fromModel.previousChunks![0].content[0].text,
    'middleware chunk '
  );
  assert.strictEqual(fromModel.accumulatedText, 'middleware chunk model chunk');
});
});