@genkit-ai/ai
Version:
Genkit AI framework generative AI APIs.
178 lines • 5.66 kB
JavaScript
import { z } from "@genkit-ai/core";
import { runInNewSpan } from "@genkit-ai/core/tracing";
import * as assert from "assert";
import { generate } from "../generate";
import { defineTool } from "../tool";
/**
 * Resolves `model` in the registry and skips the current test unless the
 * model's metadata declares support for `feature` (e.g. "media", "tools").
 *
 * Metadata missing at any level (no `metadata`, no `model`, no `supports`) is
 * treated as "not supported", so the test is skipped rather than crashing
 * with a TypeError.
 *
 * @param {object} registry - registry used to look up the model action.
 * @param {string} model - model name; looked up as `/model/${model}`.
 * @param {string} feature - key under `metadata.model.supports` to check.
 * @throws via `skip()` when the feature is not declared as supported.
 */
async function skipUnlessSupported(registry, model, feature) {
  const resolvedModel = await registry.lookupAction(`/model/${model}`);
  if (!resolvedModel.__action.metadata?.model?.supports?.[feature]) {
    skip();
  }
}

/**
 * Conformance checks that `testModels` runs against each model under test.
 *
 * Each entry maps a human-readable test name to an async function
 * `(registry, model) => Promise<void>`. A test passes by resolving, fails by
 * throwing, and is skipped by calling `skip()`. Capability-dependent tests skip
 * when the model's metadata does not declare the capability they exercise.
 */
const tests = {
  "basic hi": async (registry, model) => {
    const response = await generate(registry, {
      model,
      prompt: 'just say "Hi", literally'
    });
    assert.match(response.text.trim(), /Hi/i);
  },
  multimodal: async (registry, model) => {
    await skipUnlessSupported(registry, model, "media");
    const response = await generate(registry, {
      model,
      prompt: [
        {
          media: {
            // NOTE(review): this URL is empty, so no image reaches the model and
            // the "plus" assertion below cannot reliably pass. It looks like an
            // inline image (presumably a data: URL showing a "+" sign) was
            // stripped from this copy -- restore it from upstream.
            url: ""
          }
        },
        {
          text: "what math operation is this? plus, minus, multiply or divide?"
        }
      ]
    });
    assert.match(response.text.trim(), /plus/i);
  },
  history: async (registry, model) => {
    await skipUnlessSupported(registry, model, "multiturn");
    // The first turn introduces a fact; the second must recall it from history.
    const intro = await generate(registry, {
      model,
      prompt: "My name is Glorb"
    });
    const response = await generate(registry, {
      model,
      prompt: "What's my name?",
      messages: intro.messages
    });
    assert.match(response.text.trim(), /Glorb/);
  },
  "system prompt": async (registry, model) => {
    const { text } = await generate(registry, {
      model,
      prompt: "Hi",
      messages: [
        {
          role: "system",
          content: [
            {
              text: 'If the user says "Hi", just say "Bye" '
            }
          ]
        }
      ]
    });
    assert.strictEqual(text.trim(), "Bye");
  },
  "structured output": async (registry, model) => {
    const response = await generate(registry, {
      model,
      prompt: "extract data as json from: Jack was a Lumberjack",
      output: {
        format: "json",
        schema: z.object({
          name: z.string(),
          occupation: z.string()
        })
      }
    });
    // Node's assert API takes (actual, expected); strict mode avoids `==` coercion.
    assert.deepStrictEqual(response.output, {
      name: "Jack",
      occupation: "Lumberjack"
    });
  },
  "tool calling": async (registry, model) => {
    await skipUnlessSupported(registry, model, "tools");
    const { text } = await generate(registry, {
      model,
      prompt: "what is a gablorken of 2? use provided tool",
      tools: ["gablorkenTool"]
    });
    // gablorkenTool(2) = 2^3 + 1.407 = 9.407. The dot is escaped so that only a
    // literal "9.407" matches (an unescaped "." accepts any character).
    assert.match(text.trim(), /9\.407/);
  }
};
/**
 * Runs every conformance test in `tests` against each of the given models and
 * returns a per-test, per-model report.
 *
 * First registers `gablorkenTool` (used by the "tool calling" test). Then it
 * executes the tests one at a time, in declaration order, inside a
 * "testModels" tracing span. Each test gets its own child span named after it.
 *
 * @param {object} registry - registry passed to `defineTool`, the spans and each test.
 * @param {string[]} models - model names to exercise, in order.
 * @returns {Promise<Array<{description: string, models: Array<object>}>>}
 *   One entry per test. Each model entry is `{ name, passed }`, plus either
 *   `skipped: true` (a skip; `passed` is false in that case) or
 *   `error: { message, stack? }` when the test threw.
 */
async function testModels(registry, models) {
  defineTool(
    registry,
    {
      name: "gablorkenTool",
      description: "use when need to calculate a gablorken",
      inputSchema: z.object({ value: z.number() }),
      outputSchema: z.number()
    },
    async ({ value }) => Math.pow(value, 3) + 1.407
  );

  const runEveryTest = async () => {
    const report = [];
    for (const [description, runTest] of Object.entries(tests)) {
      await runInNewSpan(registry, { metadata: { name: description } }, async () => {
        const caseReport = { description, models: [] };
        report.push(caseReport);
        for (const name of models) {
          // Assume success; flipped below if the test throws (skips included).
          const modelReport = { name, passed: true };
          caseReport.models.push(modelReport);
          try {
            await runTest(registry, name);
          } catch (failure) {
            modelReport.passed = false;
            if (failure instanceof SkipTestError) {
              modelReport.skipped = true;
            } else {
              // Non-Error throwables (strings, plain objects) have no stack.
              modelReport.error =
                failure instanceof Error
                  ? { message: failure.message, stack: failure.stack }
                  : { message: `${failure}` };
            }
          }
        }
      });
    }
    return report;
  };

  return await runInNewSpan(registry, { metadata: { name: "testModels" } }, runEveryTest);
}
/**
 * Error thrown by `skip()` to mark a model test as skipped rather than failed.
 * `testModels` detects it with `instanceof` and records `skipped: true`.
 */
class SkipTestError extends Error {
  /**
   * @param {string} [message="test skipped"] - optional reason for skipping.
   * @param {object} [options] - standard Error options (e.g. `{ cause }`).
   */
  constructor(message = "test skipped", options) {
    super(message, options);
    // Without this, stack traces and `String(err)` report a generic "Error".
    this.name = "SkipTestError";
  }
}

/**
 * Aborts the currently running model test and marks it as skipped.
 *
 * @param {string} [reason] - optional explanation; a default message is used when omitted.
 * @throws {SkipTestError} always.
 */
function skip(reason) {
  throw new SkipTestError(reason);
}
export {
testModels
};
//# sourceMappingURL=model-tester.mjs.map