@genkit-ai/ai
Version:
Genkit AI framework generative AI APIs.
166 lines (147 loc) • 4.84 kB
text/typescript
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { stripUndefinedProps, z } from '@genkit-ai/core';
import { initNodeFeatures } from '@genkit-ai/core/node';
import { Registry } from '@genkit-ai/core/registry';
import * as assert from 'assert';
import { readFileSync } from 'fs';
import { beforeEach, describe, it } from 'node:test';
import { parse } from 'yaml';
import {
defineGenerateAction,
type GenerateAction,
} from '../../src/generate/action.js';
import { generateMiddleware } from '../../src/generate/middleware.js';
import {
GenerateActionOptionsSchema,
GenerateResponseChunkSchema,
GenerateResponseSchema,
type GenerateResponseChunkData,
} from '../../src/model.js';
import { defineTool, tool } from '../../src/tool.js';
import { defineProgrammableModel, type ProgrammableModel } from '../helpers.js';
initNodeFeatures();

/**
 * Schema for a single case in the generate spec YAML.
 *
 * - `input`: options passed to the `/util/generate` action.
 * - `modelResponses`: responses returned by the programmable model, one per model call.
 * - `streamChunks`: optional chunks to stream, grouped per model call.
 * - `stream`: when true, the action is invoked in streaming mode.
 * - `expectResponse` / `expectChunks`: expected final response and streamed chunks.
 *
 * `.strict()` rejects unknown keys, so typos in the spec file fail parsing.
 */
const SpecTestCaseSchema = z
  .object({
    name: z.string(),
    input: GenerateActionOptionsSchema,
    streamChunks: z.array(z.array(GenerateResponseChunkSchema)).optional(),
    modelResponses: z.array(GenerateResponseSchema),
    expectResponse: GenerateResponseSchema.optional(),
    stream: z.boolean().optional(),
    expectChunks: z.array(GenerateResponseChunkSchema).optional(),
  })
  .strict();

/** Top-level shape of the spec file: a strict object holding a list of cases. */
const SpecSuiteSchema = z
  .object({
    tests: z.array(SpecTestCaseSchema),
  })
  .strict();
/**
 * Data-driven suite. Each case in the spec YAML:
 * 1. scripts the programmable model's responses (and optional per-call stream chunks),
 * 2. invokes the `/util/generate` action, in streaming mode when `stream` is set,
 * 3. compares the result against `expectResponse` / `expectChunks`.
 */
describe('spec', () => {
  let registry: Registry;
  let pm: ProgrammableModel;

  beforeEach(() => {
    registry = new Registry();
    defineGenerateAction(registry);
    pm = defineProgrammableModel(registry);
    defineTool(
      registry,
      { name: 'testTool', description: 'description' },
      async () => 'tool called'
    );
  });

  // NOTE(review): a relative path passed to readFileSync is resolved against
  // process.cwd(). This therefore assumes the runner starts in the package
  // directory — confirm.
  const suite = SpecSuiteSchema.parse(
    parse(readFileSync('../../tests/specs/generate.yaml', 'utf-8'))
  );

  for (const test of suite.tests) {
    it(test.name, async () => {
      // Number of model calls served so far. Each call consumes one entry of
      // `modelResponses`, plus the matching entry of `streamChunks` if present.
      let reqCounter = 0;
      pm.handleResponse = async (_req, sc) => {
        const turn = reqCounter++;
        if (sc) {
          // Call the streaming callback with exactly one argument:
          // `forEach(sc)` would also pass the index and the array. A call with
          // no scripted chunk group streams nothing instead of throwing.
          for (const chunk of test.streamChunks?.[turn] ?? []) {
            sc(chunk);
          }
        }
        const response = test.modelResponses[turn];
        if (response === undefined) {
          // Fail with a clear message instead of returning `undefined`
          // into the generate pipeline.
          throw new Error(
            `Spec "${test.name}": model call #${turn + 1} has no scripted response ` +
              `(only ${test.modelResponses.length} modelResponses provided).`
          );
        }
        return response;
      };

      const action = (await registry.lookupAction(
        '/util/generate'
      )) as GenerateAction;
      assert.ok(action, 'generate action "/util/generate" is not registered');

      if (test.stream) {
        const { output, stream } = action.stream(test.input);
        const chunks: GenerateResponseChunkData[] = [];
        for await (const chunk of stream) {
          chunks.push(stripUndefinedProps(chunk));
        }
        assert.deepStrictEqual(chunks, test.expectChunks);
        assert.deepStrictEqual(
          stripUndefinedProps(await output),
          test.expectResponse
        );
      } else {
        const response = await action(test.input);
        assert.deepStrictEqual(
          stripUndefinedProps(response),
          test.expectResponse
        );
      }
    });
  }
});
/**
 * Checks that tools contributed by a generate middleware reach the model when
 * the request goes through the `/util/generate` action.
 */
describe('generateAction middleware injection', () => {
  let registry: Registry;
  let pm: ProgrammableModel;

  beforeEach(() => {
    registry = new Registry();
    defineGenerateAction(registry);
    pm = defineProgrammableModel(registry);
  });

  it('supports injecting tools through middleware definitions directly via action route', async () => {
    // Created with `tool(...)`, which takes no registry argument. Presumably
    // it is therefore not registered here, and the middleware below is its
    // only route to the model.
    const injectedTool = tool(
      {
        name: 'injectedTool',
        description: 'desc',
        inputSchema: z.object({ arg: z.string() }),
      },
      async ({ arg }) => `Result: ${arg}`
    );

    // Set to true once any model request lists the injected tool.
    let toolsSeen = false;
    pm.handleResponse = async (req) => {
      const advertised = req.tools?.some((def) => def.name === 'injectedTool');
      if (advertised) {
        toolsSeen = true;
      }
      const reply = {
        message: { role: 'model', content: [{ text: 'done' }] },
        finishReason: 'stop',
      } as any;
      return reply;
    };

    // Middleware whose only contribution is the extra tool.
    const dummyMw = generateMiddleware({ name: 'dummyMw' }, () => ({
      tools: [injectedTool],
    }));

    const action = await registry.lookupAction('/util/generate');
    const request = {
      model: 'programmableModel',
      messages: [{ role: 'user', content: [{ text: 'test' }] }],
      use: [dummyMw()],
    } as any;
    await action(request);

    assert.ok(
      toolsSeen,
      'Tool was not successfully passed to the model from action generated route.'
    );
  });
});