UNPKG

@genkit-ai/checks

Version:

Google Checks AI Safety plugins for classifying the safety of text against Checks AI safety policies.

108 lines 2.87 kB
import { z } from "genkit";
import { runInNewSpan } from "genkit/tracing";
import { isConfig } from "./metrics";

/**
 * Builds a single Genkit evaluator ("checks/guardrails") that classifies text
 * against the given Checks AI safety policies.
 *
 * @param {object} ai - Genkit instance; its `defineEvaluator` registers the evaluator.
 * @param {object} auth - Object whose async `getClient()` yields an HTTP client with a
 *   `request(options)` method (presumably a google-auth-library `GoogleAuth` — confirm).
 * @param {Array} metrics - Policies to check. Each entry is either a bare policy type, or a
 *   config object (as recognised by `isConfig`) carrying `type` and an optional `threshold`.
 * @param {string} projectId - Sent to the API as the `x-goog-user-project` header.
 * @returns {*} Whatever `ai.defineEvaluator` returns.
 */
function checksEvaluators(ai, auth, metrics, projectId) {
  const policyConfigs = metrics.map((metric) =>
    isConfig(metric)
      ? { type: metric.type, threshold: metric.threshold }
      : { type: metric, threshold: undefined }
  );
  return createPolicyEvaluator(projectId, auth, ai, policyConfigs);
}

/** Expected shape of the classifyContent response; validated before results are mapped. */
const ResponseSchema = z.object({
  policyResults: z.array(
    z.object({
      policyType: z.string(),
      score: z.number(),
      violationResult: z.string(),
    })
  ),
});

/**
 * Registers the "checks/guardrails" evaluator for the given policy configs.
 *
 * @param {string} projectId - Billing/quota project forwarded to the API.
 * @param {object} auth - Auth provider used to obtain the HTTP client.
 * @param {object} ai - Genkit instance.
 * @param {Array<{type: *, threshold: *}>} policyConfigs - Policies to evaluate against.
 * @returns {*} Whatever `ai.defineEvaluator` returns.
 */
function createPolicyEvaluator(projectId, auth, ai, policyConfigs) {
  // Join explicitly: interpolating the array itself would render as "A,B" with no spacing.
  const policyList = policyConfigs.map((policy) => policy.type).join(", ");

  return ai.defineEvaluator(
    {
      name: "checks/guardrails",
      displayName: "checks/guardrails",
      definition: `Evaluates input text against the Checks ${policyList} policies.`,
    },
    async (datapoint) => {
      // NOTE(review): the text sent for classification is `datapoint.output`, while the
      // definition above says "input text" — confirm which wording is intended.
      const partialRequest = {
        input: { text_input: { content: datapoint.output } },
        // A missing threshold stays `undefined`, which JSON.stringify omits from the body.
        policies: policyConfigs.map(({ type, threshold }) => ({
          policy_type: type,
          threshold,
        })),
      };

      const response = await checksEvalInstance(
        ai,
        projectId,
        auth,
        partialRequest,
        ResponseSchema
      );

      // One evaluation entry per policy result returned by the API.
      const evaluation = response.policyResults.map((result) => ({
        id: result.policyType,
        score: result.score,
        details: { reasoning: `Status ${result.violationResult}` },
      }));

      return { evaluation, testCaseId: datapoint.testCaseId };
    }
  );
}

/**
 * POSTs a classification request to the Checks classifyContent endpoint inside a
 * tracing span, records request/response on the span metadata, and validates the
 * response against `responseSchema`.
 *
 * @param {object} ai - Passed through to `runInNewSpan`.
 * @param {string} projectId - Sent as the `x-goog-user-project` header.
 * @param {object} auth - Provides `getClient()` for the authenticated HTTP client.
 * @param {object} partialRequest - Request body (shallow-copied before sending).
 * @param {object} responseSchema - Zod schema used to parse `response.data`.
 * @returns {Promise<*>} The parsed response.
 * @throws {Error} When the response does not match `responseSchema`; the original
 *   validation error is attached as `cause`.
 */
async function checksEvalInstance(ai, projectId, auth, partialRequest, responseSchema) {
  return await runInNewSpan(
    ai,
    // NOTE(review): span name does not match the classifyContent endpoint; left unchanged
    // because existing traces/dashboards may key on it.
    { metadata: { name: "EvaluationService#evaluateInstances" } },
    async (metadata) => {
      const request = { ...partialRequest };
      metadata.input = request;

      const client = await auth.getClient();
      const url = "https://checks.googleapis.com/v1alpha/aisafety:classifyContent";
      const response = await client.request({
        url,
        method: "POST",
        body: JSON.stringify(request),
        headers: {
          "x-goog-user-project": projectId,
          "Content-Type": "application/json",
        },
      });
      metadata.output = response.data;

      try {
        return responseSchema.parse(response.data);
      } catch (e) {
        // Preserve the underlying validation error (field paths etc.) for debugging.
        throw new Error(`Error parsing ${url} API response: ${e}`, { cause: e });
      }
    }
  );
}

export { checksEvaluators };
//# sourceMappingURL=evaluation.mjs.map