UNPKG

@genkit-ai/ai

Version:

Genkit AI framework generative AI APIs.

178 lines 5.66 kB
import { z } from "@genkit-ai/core"; import { runInNewSpan } from "@genkit-ai/core/tracing"; import * as assert from "assert"; import { generate } from "../generate"; import { defineTool } from "../tool"; const tests = { "basic hi": async (registry, model) => { const response = await generate(registry, { model, prompt: 'just say "Hi", literally' }); const got = response.text.trim(); assert.match(got, /Hi/i); }, multimodal: async (registry, model) => { const resolvedModel = await registry.lookupAction( `/model/${model}` ); if (!resolvedModel.__action.metadata?.model.supports?.media) { skip(); } const response = await generate(registry, { model, prompt: [ { media: { url: "" } }, { text: "what math operation is this? plus, minus, multiply or divide?" } ] }); const want = /plus/i; const got = response.text.trim(); assert.match(got, want); }, history: async (registry, model) => { const resolvedModel = await registry.lookupAction( `/model/${model}` ); if (!resolvedModel.__action.metadata?.model.supports?.multiturn) { skip(); } const response1 = await generate(registry, { model, prompt: "My name is Glorb" }); const response = await generate(registry, { model, prompt: "What's my name?", messages: response1.messages }); const got = response.text.trim(); assert.match(got, /Glorb/); }, "system prompt": async (registry, model) => { const { text } = await generate(registry, { model, prompt: "Hi", messages: [ { role: "system", content: [ { text: 'If the user says "Hi", just say "Bye" ' } ] } ] }); const want = "Bye"; const got = text.trim(); assert.equal(got, want); }, "structured output": async (registry, model) => { const response = await generate(registry, { model, prompt: "extract data as json from: Jack was a Lumberjack", output: { format: "json", schema: z.object({ name: z.string(), occupation: z.string() }) } }); const want = { name: "Jack", occupation: "Lumberjack" }; const got = response.output; assert.deepEqual(want, got); }, "tool calling": async (registry, model) => { const resolvedModel = await registry.lookupAction( `/model/${model}` ); if (!resolvedModel.__action.metadata?.model.supports?.tools) { skip(); } const { text } = await generate(registry, { model, prompt: "what is a gablorken of 2? use provided tool", tools: ["gablorkenTool"] }); const got = text.trim(); assert.match(got, /9.407/); } }; async function testModels(registry, models) { defineTool( registry, { name: "gablorkenTool", description: "use when need to calculate a gablorken", inputSchema: z.object({ value: z.number() }), outputSchema: z.number() }, async (input) => { return Math.pow(input.value, 3) + 1.407; } ); return await runInNewSpan( registry, { metadata: { name: "testModels" } }, async () => { const report = []; for (const test of Object.keys(tests)) { await runInNewSpan(registry, { metadata: { name: test } }, async () => { report.push({ description: test, models: [] }); const caseReport = report[report.length - 1]; for (const model of models) { caseReport.models.push({ name: model, passed: true // optimistically }); const modelReport = caseReport.models[caseReport.models.length - 1]; try { await tests[test](registry, model); } catch (e) { modelReport.passed = false; if (e instanceof SkipTestError) { modelReport.skipped = true; } else if (e instanceof Error) { modelReport.error = { message: e.message, stack: e.stack }; } else { modelReport.error = { message: `${e}` }; } } } }); } return report; } ); } class SkipTestError extends Error { } function skip() { throw new SkipTestError(); } export { testModels }; //# sourceMappingURL=model-tester.mjs.map