/**
 * @genkit-ai/checks
 * Google Checks AI Safety plugins for classifying the safety of text
 * against Checks AI safety policies.
 */
// Compiled CommonJS output: esbuild-style ESM→CJS interop shims follow.
;
// Local aliases for the Object.* primitives used by the shims below
// (shortens the emitted code; behavior is identical to the globals).
var __defProp = Object.defineProperty;
var __getOwnPropDesc = Object.getOwnPropertyDescriptor;
var __getOwnPropNames = Object.getOwnPropertyNames;
var __hasOwnProp = Object.prototype.hasOwnProperty;
// Defines each entry of `all` on `target` as a live, enumerable getter so
// importers always observe the module's current binding (ESM-like semantics).
var __export = (target, all) => {
  for (const name in all) {
    Object.defineProperty(target, name, { get: all[name], enumerable: true });
  }
};
// Copies every own property of `from` onto `to` as a live getter, skipping
// `except` and any key `to` already owns. Enumerability follows the source
// descriptor (defaulting to enumerable when no descriptor is found).
// Returns `to` for chaining.
var __copyProps = (to, from, except, desc) => {
  const copyable =
    (from && typeof from === "object") || typeof from === "function";
  if (copyable) {
    for (const key of Object.getOwnPropertyNames(from)) {
      if (key === except) continue;
      if (Object.prototype.hasOwnProperty.call(to, key)) continue;
      desc = Object.getOwnPropertyDescriptor(from, key);
      Object.defineProperty(to, key, {
        get: () => from[key],
        enumerable: !desc || desc.enumerable,
      });
    }
  }
  return to;
};
var __toCommonJS = (mod) => __copyProps(__defProp({}, "__esModule", { value: true }), mod);
// Registry of this module's named exports, exposed as live getters
// (checksMiddleware is hoisted, so referencing it here is safe).
var middleware_exports = {};
__export(middleware_exports, {
  checksMiddleware: () => checksMiddleware
});
// Publish as CommonJS, marked __esModule for ESM interop.
module.exports = __toCommonJS(middleware_exports);
// Project-local Guardrails client that performs the Checks classification.
var import_guardrails = require("./guardrails");
/**
 * Genkit model middleware that screens both the model input and the model
 * output against Google Checks AI safety policies.
 *
 * Every text part of the request is classified before the model is invoked;
 * if any part violates a policy, a `blocked` response is returned and the
 * model is never called. Otherwise the model runs and each text part of its
 * output is classified the same way before being returned.
 *
 * @param {Object} options - Middleware configuration.
 * @param {*} options.auth - Forwarded to the Guardrails constructor
 *   (presumably credentials — confirm against ./guardrails).
 * @param {string} [options.projectId] - Optional project id forwarded to
 *   the Guardrails constructor.
 * @param {*} options.metrics - Forwarded to guardrails.classifyContent;
 *   selects which Checks policies to classify against.
 * @returns {Function} Async middleware `(req, next) => response`.
 */
function checksMiddleware(options) {
  // `options` is required — options.auth is read unconditionally — so the
  // `options?.projectId` optional chaining the compiler emitted was redundant.
  const guardrails = new import_guardrails.Guardrails(options.auth, options.projectId);

  // Classify one piece of text and keep only the policies it violated.
  const classifyContent = async (content) => {
    const response = await guardrails.classifyContent(content, options.metrics);
    return response.policyResults.filter(
      (policy) => policy.violationResult === "VIOLATIVE"
    );
  };

  // Scan a list of message parts sequentially; return the violated policies
  // of the first text part that violates any policy, or null if all parts
  // are clean. (Sequential on purpose: stop classifying once blocked.)
  const findViolations = async (contents) => {
    for (const content of contents ?? []) {
      if (content.text) {
        const violatedPolicies = await classifyContent(content.text);
        if (violatedPolicies.length > 0) {
          return violatedPolicies;
        }
      }
    }
    return null;
  };

  // Render the violated policy types for the finish message.
  const policyList = (policies) =>
    policies.map((result) => result.policyType).join(" ");

  return async (req, next) => {
    // Screen the model INPUT before invoking the model at all.
    for (const message of req.messages) {
      const violatedPolicies = await findViolations(message.content);
      if (violatedPolicies) {
        return {
          finishReason: "blocked",
          finishMessage: `Model input violated Checks policies: [${policyList(violatedPolicies)}], further processing blocked.`
        };
      }
    }
    const generatedContent = await next(req);
    // Screen the model OUTPUT the same way before handing it back.
    for (const candidate of generatedContent.candidates ?? []) {
      const violatedPolicies = await findViolations(candidate.message.content);
      if (violatedPolicies) {
        return {
          finishReason: "blocked",
          finishMessage: `Model output violated Checks policies: [${policyList(violatedPolicies)}], output blocked.`
        };
      }
    }
    return generatedContent;
  };
}
// Annotate the CommonJS export names for ESM import in node:
// (dead code by design — the `0 &&` guard means it never executes; Node's
// CJS named-export detection parses it statically to discover the names.)
0 && (module.exports = {
  checksMiddleware
});
//# sourceMappingURL=middleware.js.map