UNPKG

@genkit-ai/checks

Version:

Google Checks AI Safety plugins for classifying the safety of text against Checks AI safety policies.

69 lines 2.73 kB
"use strict";
// --- esbuild CommonJS interop helpers (generated; kept as emitted) ---
var __defProp = Object.defineProperty;
var __getOwnPropDesc = Object.getOwnPropertyDescriptor;
var __getOwnPropNames = Object.getOwnPropertyNames;
var __hasOwnProp = Object.prototype.hasOwnProperty;
var __export = (target, all) => {
  for (var name in all)
    __defProp(target, name, { get: all[name], enumerable: true });
};
var __copyProps = (to, from, except, desc) => {
  if (from && typeof from === "object" || typeof from === "function") {
    for (let key of __getOwnPropNames(from))
      if (!__hasOwnProp.call(to, key) && key !== except)
        __defProp(to, key, { get: () => from[key], enumerable: !(desc = __getOwnPropDesc(from, key)) || desc.enumerable });
  }
  return to;
};
var __toCommonJS = (mod) => __copyProps(__defProp({}, "__esModule", { value: true }), mod);
var middleware_exports = {};
__export(middleware_exports, {
  checksMiddleware: () => checksMiddleware
});
module.exports = __toCommonJS(middleware_exports);
var import_guardrails = require("./guardrails");

/**
 * Builds a model middleware that screens text with the Checks Guardrails
 * client before and after the wrapped model runs.
 *
 * Every non-empty text part of every request message is classified first.
 * If any part has a VIOLATIVE policy result, the model is never called and a
 * `blocked` response is returned instead. Otherwise the model runs, and the
 * non-empty text parts of its output are classified the same way. A
 * violation there replaces the model output with a `blocked` response.
 *
 * @param {object} options
 * @param {*} options.auth - Credentials passed as the first argument to
 *   `Guardrails` (presumably an auth client — TODO confirm against ./guardrails).
 * @param {*} options.metrics - Forwarded unchanged to
 *   `Guardrails#classifyContent` as the set of policies to evaluate.
 * @param {string} [options.projectId] - Optional project id forwarded to `Guardrails`.
 * @returns {(req: object, next: Function) => Promise<object>} The middleware.
 */
function checksMiddleware(options) {
  const guardrails = new import_guardrails.Guardrails(options.auth, options.projectId);

  // Returns the policy results the Checks API marked VIOLATIVE for `text`.
  const classifyContent = async (text) => {
    const response = await guardrails.classifyContent(text, options.metrics);
    // proto3 JSON omits empty repeated fields, so `policyResults` may be
    // absent when nothing was reported; treat that as "no violations".
    return (response.policyResults ?? []).filter(
      (policy) => policy.violationResult === "VIOLATIVE"
    );
  };

  // Classifies text parts in order and returns the violations of the first
  // violating part, or null if none. Stops at the first hit so no further
  // Checks calls are made.
  const findViolations = async (messages) => {
    for (const message of messages) {
      for (const part of message?.content ?? []) {
        if (!part.text) continue;
        const violatedPolicies = await classifyContent(part.text);
        if (violatedPolicies.length > 0) return violatedPolicies;
      }
    }
    return null;
  };

  // Space-separated policy types, as embedded in the finishMessage text.
  const describePolicies = (violatedPolicies) =>
    violatedPolicies.map((result) => result.policyType).join(" ");

  return async (req, next) => {
    const inputViolations = await findViolations(req.messages);
    if (inputViolations) {
      return {
        finishReason: "blocked",
        finishMessage: `Model input violated Checks policies: [${describePolicies(inputViolations)}], further processing blocked.`
      };
    }

    const generatedContent = await next(req);

    const outputViolations = await findViolations(collectOutputMessages(generatedContent));
    if (outputViolations) {
      return {
        finishReason: "blocked",
        finishMessage: `Model output violated Checks policies: [${describePolicies(outputViolations)}], output blocked.`
      };
    }
    return generatedContent;
  };
}

/**
 * Gathers the messages a model response carries. Responses may put the
 * output in a single `message` field, in the older `candidates[].message`
 * form, or both. All of them are collected so no output shape bypasses the
 * check. A Set visits each distinct message object only once.
 *
 * @param {object} response - The value returned by the next middleware/model.
 * @returns {object[]} Messages to classify; empty when the response has none.
 */
function collectOutputMessages(response) {
  const messages = new Set();
  if (response.message) messages.add(response.message);
  for (const candidate of response.candidates ?? []) {
    if (candidate?.message) messages.add(candidate.message);
  }
  return [...messages];
}
// Annotate the CommonJS export names for ESM import in node:
0 && (module.exports = {
  checksMiddleware
});
//# sourceMappingURL=middleware.js.map