@genkit-ai/checks
Version:
Google Checks AI Safety plugins for classifying the safety of text against Checks AI safety policies.
108 lines • 2.87 kB
JavaScript
import { z } from "genkit";
import { runInNewSpan } from "genkit/tracing";
import {
isConfig
} from "./metrics";
/**
 * Builds the single Checks guardrails evaluator for the requested metrics.
 *
 * Each entry of `metrics` is either a bare metric type or a config object
 * (as recognised by `isConfig`) carrying `type` and `threshold`. Bare types
 * are normalised to `{ type, threshold: undefined }` so every policy config
 * has the same shape.
 *
 * @param {*} ai - Genkit instance used to register the evaluator.
 * @param {*} auth - Auth object whose `getClient()` yields the HTTP client.
 * @param {Array} metrics - Metric types and/or metric config objects.
 * @param {string} projectId - Sent as the `x-goog-user-project` request header.
 * @returns {*} Whatever `createPolicyEvaluator` returns.
 */
function checksEvaluators(ai, auth, metrics, projectId) {
  const toPolicyConfig = (metric) => {
    if (isConfig(metric)) {
      return { type: metric.type, threshold: metric.threshold };
    }
    return { type: metric, threshold: void 0 };
  };
  const policyConfigs = metrics.map(toPolicyConfig);
  return createPolicyEvaluator(projectId, auth, ai, policyConfigs);
}
/**
 * Expected shape of the Checks `classifyContent` response: a list of policy
 * results, each with its policy type, numeric score and violation status
 * string. Validated by `responseSchema.parse` before results are used.
 */
const ResponseSchema = z.object({
  policyResults: z
    .object({
      policyType: z.string(),
      score: z.number(),
      violationResult: z.string()
    })
    .array()
});
/**
 * Registers the `checks/guardrails` evaluator with Genkit.
 *
 * For each datapoint, the datapoint's `output` is sent to the Checks
 * classification endpoint together with every configured policy. Each
 * returned policy result becomes one evaluation entry whose `id` is the
 * policy type.
 *
 * @param {string} projectId - Sent as the `x-goog-user-project` request header.
 * @param {*} auth - Auth object whose `getClient()` yields the HTTP client.
 * @param {*} ai - Genkit instance providing `defineEvaluator`.
 * @param {Array<{type: string, threshold: (number|undefined)}>} policy_config -
 *   Policies to evaluate against, in request order.
 * @returns {*} The evaluator returned by `ai.defineEvaluator`.
 */
function createPolicyEvaluator(projectId, auth, ai, policy_config) {
  const policyTypes = policy_config.map((policy) => policy.type);
  return ai.defineEvaluator(
    {
      name: "checks/guardrails",
      displayName: "checks/guardrails",
      // Join explicitly: interpolating the array directly renders "A,B" with no spacing.
      definition: `Evaluates input text against the Checks ${policyTypes.join(", ")} policies.`
    },
    async (datapoint) => {
      const { output } = datapoint;
      // The request field is text; serialize structured outputs rather than
      // sending a raw object. `JSON.stringify(undefined)` stays `undefined`.
      const content = typeof output === "string" ? output : JSON.stringify(output);
      const partialRequest = {
        input: {
          text_input: {
            content
          }
        },
        policies: policy_config.map(({ type, threshold }) => ({
          policy_type: type,
          threshold
        }))
      };
      const response = await checksEvalInstance(
        ai,
        projectId,
        auth,
        partialRequest,
        ResponseSchema
      );
      const evaluation = response.policyResults.map(
        ({ policyType, score, violationResult }) => ({
          id: policyType,
          score,
          details: {
            reasoning: `Status ${violationResult}`
          }
        })
      );
      return {
        evaluation,
        testCaseId: datapoint.testCaseId
      };
    }
  );
}
/**
 * POSTs a classification request to the Checks `aisafety:classifyContent`
 * endpoint inside a new tracing span, then validates the response body.
 *
 * The span's metadata records the request as `input` and the raw response
 * body as `output`.
 *
 * @param {*} ai - Genkit instance passed through to `runInNewSpan`.
 * @param {string} projectId - Sent as the `x-goog-user-project` request header.
 * @param {*} auth - Object whose `getClient()` resolves to a client exposing `request()`.
 * @param {object} partialRequest - Request payload; shallow-copied before sending.
 * @param {*} responseSchema - Schema whose `parse()` validates `response.data`.
 * @returns {Promise<*>} The value returned by `responseSchema.parse`.
 * @throws {Error} When the response fails validation. The message names the
 *   URL, and the original validation error is attached as `cause`. Errors
 *   from `getClient()` or `request()` propagate unchanged.
 */
async function checksEvalInstance(ai, projectId, auth, partialRequest, responseSchema) {
  return await runInNewSpan(
    ai,
    {
      metadata: {
        // NOTE(review): this span name looks carried over from another evaluator;
        // the call below is classifyContent. Kept as-is so existing traces match.
        name: "EvaluationService#evaluateInstances"
      }
    },
    async (metadata) => {
      const request = {
        ...partialRequest
      };
      metadata.input = request;
      const client = await auth.getClient();
      const url = "https://checks.googleapis.com/v1alpha/aisafety:classifyContent";
      const response = await client.request({
        url,
        method: "POST",
        body: JSON.stringify(request),
        headers: {
          "x-goog-user-project": projectId,
          "Content-Type": "application/json"
        }
      });
      metadata.output = response.data;
      try {
        return responseSchema.parse(response.data);
      } catch (e) {
        // Preserve the validation error (and its stack) for debugging.
        throw new Error(`Error parsing ${url} API response: ${e}`, { cause: e });
      }
    }
  );
}
export {
checksEvaluators
};
//# sourceMappingURL=evaluation.mjs.map