@genkit-ai/checks
Version:
Google Checks AI Safety plugins for classifying the safety of text against Checks AI safety policies.
1 lines • 3.94 kB
Source Map (JSON)
{"version":3,"sources":["../src/middleware.ts"],"sourcesContent":["/**\n * Copyright 2024 Google LLC\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\nimport type { ModelMiddleware } from 'genkit/model';\nimport type { GoogleAuth } from 'google-auth-library';\nimport { Guardrails } from './guardrails';\nimport type { ChecksEvaluationMetric } from './metrics';\n\nexport function checksMiddleware(options: {\n auth: GoogleAuth;\n metrics: ChecksEvaluationMetric[];\n projectId?: string;\n}): ModelMiddleware {\n const guardrails = new Guardrails(options.auth, options?.projectId);\n\n const classifyContent = async (content: string) => {\n const response = await guardrails.classifyContent(content, options.metrics);\n\n // Filter for violations\n const violatedPolicies = response.policyResults.filter(\n (policy) => policy.violationResult === 'VIOLATIVE'\n );\n\n return violatedPolicies;\n };\n\n return async (req, next) => {\n for (const message of req.messages) {\n for (const content of message.content) {\n if (content.text) {\n const violatedPolicies = await classifyContent(content.text);\n\n // If any input message violates a checks policy. Stop processing,\n // return a blocked response and list of violated policies.\n if (violatedPolicies.length > 0) {\n return {\n finishReason: 'blocked',\n finishMessage: `Model input violated Checks policies: [${violatedPolicies.map((result) => result.policyType).join(' ')}], further processing blocked.`,\n };\n }\n }\n }\n }\n\n const generatedContent = await next(req);\n\n for (const candidate of generatedContent.candidates ?? []) {\n for (const content of candidate.message.content ?? []) {\n if (content.text) {\n const violatedPolicies = await classifyContent(content.text);\n\n // If the output message violates a checks policy. Stop processing,\n // return a blocked response and list of violated policies.\n if (violatedPolicies.length > 0) {\n return {\n finishReason: 'blocked',\n finishMessage: `Model output violated Checks policies: [${violatedPolicies.map((result) => result.policyType).join(' ')}], output blocked.`,\n };\n }\n }\n }\n }\n\n return generatedContent;\n };\n}\n"],"mappings":"AAkBA,SAAS,kBAAkB;AAGpB,SAAS,iBAAiB,SAIb;AAClB,QAAM,aAAa,IAAI,WAAW,QAAQ,MAAM,SAAS,SAAS;AAElE,QAAM,kBAAkB,OAAO,YAAoB;AACjD,UAAM,WAAW,MAAM,WAAW,gBAAgB,SAAS,QAAQ,OAAO;AAG1E,UAAM,mBAAmB,SAAS,cAAc;AAAA,MAC9C,CAAC,WAAW,OAAO,oBAAoB;AAAA,IACzC;AAEA,WAAO;AAAA,EACT;AAEA,SAAO,OAAO,KAAK,SAAS;AAC1B,eAAW,WAAW,IAAI,UAAU;AAClC,iBAAW,WAAW,QAAQ,SAAS;AACrC,YAAI,QAAQ,MAAM;AAChB,gBAAM,mBAAmB,MAAM,gBAAgB,QAAQ,IAAI;AAI3D,cAAI,iBAAiB,SAAS,GAAG;AAC/B,mBAAO;AAAA,cACL,cAAc;AAAA,cACd,eAAe,0CAA0C,iBAAiB,IAAI,CAAC,WAAW,OAAO,UAAU,EAAE,KAAK,GAAG,CAAC;AAAA,YACxH;AAAA,UACF;AAAA,QACF;AAAA,MACF;AAAA,IACF;AAEA,UAAM,mBAAmB,MAAM,KAAK,GAAG;AAEvC,eAAW,aAAa,iBAAiB,cAAc,CAAC,GAAG;AACzD,iBAAW,WAAW,UAAU,QAAQ,WAAW,CAAC,GAAG;AACrD,YAAI,QAAQ,MAAM;AAChB,gBAAM,mBAAmB,MAAM,gBAAgB,QAAQ,IAAI;AAI3D,cAAI,iBAAiB,SAAS,GAAG;AAC/B,mBAAO;AAAA,cACL,cAAc;AAAA,cACd,eAAe,2CAA2C,iBAAiB,IAAI,CAAC,WAAW,OAAO,UAAU,EAAE,KAAK,GAAG,CAAC;AAAA,YACzH;AAAA,UACF;AAAA,QACF;AAAA,MACF;AAAA,IACF;AAEA,WAAO;AAAA,EACT;AACF;","names":[]}