UNPKG

@genkit-ai/google-cloud

Version:

Genkit AI framework plugin for Google Cloud Platform, including a Firestore trace/state store and deployment helpers for Cloud Functions for Firebase.

205 lines 7.3 kB
"use strict";
// esbuild-generated CommonJS interop: expose the module's named exports as
// enumerable getters on `module.exports` and flag it as transpiled ESM.
var __defProp = Object.defineProperty;
var __getOwnPropDesc = Object.getOwnPropertyDescriptor;
var __getOwnPropNames = Object.getOwnPropertyNames;
var __hasOwnProp = Object.prototype.hasOwnProperty;
var __export = (target, all) => {
  for (var name in all) {
    __defProp(target, name, { get: all[name], enumerable: true });
  }
};
var __copyProps = (to, from, except, desc) => {
  if (from && typeof from === "object" || typeof from === "function") {
    for (let key of __getOwnPropNames(from)) {
      if (!__hasOwnProp.call(to, key) && key !== except) {
        __defProp(to, key, {
          get: () => from[key],
          enumerable: !(desc = __getOwnPropDesc(from, key)) || desc.enumerable
        });
      }
    }
  }
  return to;
};
var __toCommonJS = (mod) => __copyProps(__defProp({}, "__esModule", { value: true }), mod);
var model_armor_exports = {};
__export(model_armor_exports, {
  modelArmor: () => modelArmor
});
module.exports = __toCommonJS(model_armor_exports);
var import_modelarmor = require("@google-cloud/modelarmor");
var import_genkit = require("genkit");
var import_tracing = require("genkit/tracing");

/**
 * Concatenates the `text` of every part; parts without text contribute "".
 *
 * @param {Array<{text?: string}>} parts - Message content parts.
 * @returns {string} The joined text (possibly empty).
 */
function extractText(parts) {
  return parts.map((p) => p.text || "").join("");
}

/**
 * Returns a copy of `content` in which all text parts are replaced by a single
 * `{ text }` part, placed where the first text part was, so the relative order
 * of text and non-text (e.g. media) parts is preserved. When `content` has no
 * text part, the new text part is appended.
 *
 * @param {Array<{text?: string}>} content - Original message parts (not mutated).
 * @param {string} text - Replacement text.
 * @returns {Array<object>} New parts array.
 */
function replaceTextParts(content, text) {
  const newContent = [];
  let inserted = false;
  for (const part of content) {
    // Same notion of "text part" as before: a part with a truthy `text`.
    if (!part.text) {
      newContent.push(part);
    } else if (!inserted) {
      newContent.push({ text });
      inserted = true;
    }
  }
  if (!inserted) {
    newContent.push({ text });
  }
  return newContent;
}

/**
 * Applies Model Armor's Sensitive Data Protection (SDP) result to `messages`.
 *
 * - If `options.applyDeidentificationResults` is a function, it is called with
 *   `{ messages, sdpResult }`; a truthy return value becomes the new messages.
 * - If it is `true` and Model Armor returned de-identified text, the text parts
 *   of `messages[targetIndex]` are replaced by that text (non-text parts kept).
 * - Otherwise `messages` is returned unchanged.
 *
 * Never mutates `messages` or its elements.
 *
 * @param {Array<object>} messages - Messages the sanitized text came from.
 * @param {number} targetIndex - Index of the message that was sent to Model Armor.
 * @param {object} result - The `sanitizationResult` returned by Model Armor.
 * @param {object} options - Plugin options.
 * @returns {{sdpApplied: boolean, messages: Array<object>}} `sdpApplied` is true
 *   only when de-identified text was present and new messages were produced.
 */
function applySdp(messages, targetIndex, result, options) {
  const sdpFilterResult = result.filterResults?.["sdp"]?.sdpFilterResult;
  if (!sdpFilterResult) {
    return { sdpApplied: false, messages };
  }
  if (typeof options.applyDeidentificationResults === "function") {
    const newMessages = options.applyDeidentificationResults({
      messages,
      sdpResult: sdpFilterResult
    });
    if (!newMessages) {
      return { sdpApplied: false, messages };
    }
    const sdpApplied = !!sdpFilterResult.deidentifyResult?.data?.text;
    return { sdpApplied, messages: newMessages };
  }
  const deidentifiedText = sdpFilterResult.deidentifyResult?.data?.text;
  if (options.applyDeidentificationResults === true && deidentifiedText) {
    const targetMessage = messages[targetIndex];
    const newMessages = [...messages];
    newMessages[targetIndex] = {
      ...targetMessage,
      content: replaceTextParts(targetMessage.content, deidentifiedText)
    };
    return { sdpApplied: true, messages: newMessages };
  }
  return { sdpApplied: false, messages };
}

/**
 * Reports whether one entry of `sanitizationResult.filterResults` matched.
 *
 * Each entry wraps a single filter-specific result object (e.g.
 * `{ raiFilterResult: {...} }`), possibly next to a string naming the populated
 * oneof field, so the first *object* value is used rather than position 0.
 * `SdpFilterResult` has no `matchState` of its own: it is a oneof of
 * `inspectResult` / `deidentifyResult`, which carry the state.
 *
 * @param {object} filterResult - One value of `filterResults`.
 * @returns {boolean} True when the filter reported `MATCH_FOUND`.
 */
function filterResultMatched(filterResult) {
  const inner = Object.values(filterResult).find(
    (value) => value !== null && typeof value === "object"
  );
  if (!inner) {
    return false;
  }
  const matchState =
    inner.matchState ??
    inner.inspectResult?.matchState ??
    inner.deidentifyResult?.matchState;
  return matchState === "MATCH_FOUND";
}

/**
 * Decides whether a Model Armor sanitization result should block the request.
 *
 * Nothing is blocked unless the overall `filterMatchState` is `MATCH_FOUND`.
 * With `strictSdpEnforcement`, an applied de-identification blocks by itself.
 * Otherwise any matching filter blocks, except that filters missing from
 * `options.filters` (when given) are ignored, and the `sdp` filter is ignored
 * once its de-identification has been applied.
 *
 * @param {object} result - The `sanitizationResult` returned by Model Armor.
 * @param {object} options - Plugin options (`filters`, `strictSdpEnforcement`).
 * @param {boolean} sdpApplied - Whether `applySdp` rewrote the content.
 * @returns {boolean} True when the content must be blocked.
 */
function shouldBlock(result, options, sdpApplied) {
  if (result.filterMatchState !== "MATCH_FOUND") {
    return false;
  }
  if (options.strictSdpEnforcement && sdpApplied) {
    return true;
  }
  if (!result.filterResults) {
    return false;
  }
  return Object.entries(result.filterResults).some(([key, filterResult]) => {
    if (options.filters && !options.filters.includes(key)) {
      return false;
    }
    if (key === "sdp" && sdpApplied) {
      return false;
    }
    return filterResultMatched(filterResult);
  });
}

/**
 * Sends the text of the most recent `user` message in `req.messages` to Model
 * Armor's `sanitizeUserPrompt`, inside a `sanitizeUserPrompt` tracing span.
 *
 * Mutates `req`: `req.messages` is replaced when de-identification was applied
 * or a custom `applyDeidentificationResults` function is configured.
 * Does nothing when there is no user message or its text is empty.
 *
 * @param {object} req - Request whose `messages` array is inspected.
 * @param {object} client - Model Armor client exposing `sanitizeUserPrompt`.
 * @param {object} options - Plugin options (`templateName`, SDP/filter options).
 * @throws {GenkitError} `PERMISSION_DENIED` when `shouldBlock` says so.
 */
async function sanitizeUserPrompt(req, client, options) {
  let targetMessageIndex = -1;
  for (let i = req.messages.length - 1; i >= 0; i--) {
    if (req.messages[i].role === "user") {
      targetMessageIndex = i;
      break;
    }
  }
  if (targetMessageIndex === -1) {
    return;
  }
  const promptText = extractText(req.messages[targetMessageIndex].content);
  if (!promptText) {
    return;
  }
  await (0, import_tracing.runInNewSpan)(
    { metadata: { name: "sanitizeUserPrompt" } },
    async (meta) => {
      meta.input = {
        name: options.templateName,
        userPromptData: { text: promptText }
      };
      const [response] = await client.sanitizeUserPrompt({
        name: options.templateName,
        userPromptData: { text: promptText }
      });
      meta.output = response;
      const result = response.sanitizationResult;
      if (!result) {
        return;
      }
      const { sdpApplied, messages: modifiedMessages } = applySdp(
        req.messages,
        targetMessageIndex,
        result,
        options
      );
      if (sdpApplied || typeof options.applyDeidentificationResults === "function") {
        req.messages = modifiedMessages;
      }
      if (shouldBlock(result, options, sdpApplied)) {
        throw new import_genkit.GenkitError({
          status: "PERMISSION_DENIED",
          message: "Model Armor blocked user prompt.",
          detail: result
        });
      }
    }
  );
}
/**
 * Runs every model output through Model Armor's `sanitizeModelResponse`, one
 * `sanitizeModelResponse` tracing span per non-empty output.
 *
 * The outputs are `response.message` when present, otherwise each entry of
 * `response.candidates`. Mutates `response`: a candidate's `message` (and, in
 * the `message` form, `response.message`) is replaced when de-identification
 * was applied or a custom `applyDeidentificationResults` function is set.
 *
 * @param {object} response - Model response produced by `next(req)`.
 * @param {object} client - Model Armor client exposing `sanitizeModelResponse`.
 * @param {object} options - Plugin options (`templateName`, SDP/filter options).
 * @throws {GenkitError} `PERMISSION_DENIED` when any output must be blocked.
 */
async function sanitizeModelResponse(response, client, options) {
  const hasSingleMessage = !!response.message;
  let outputs;
  if (response.message) {
    // Wrap the single message so both response shapes share one code path.
    outputs = [{ index: 0, message: response.message, finishReason: "stop" }];
  } else {
    outputs = response.candidates || [];
  }

  for (const output of outputs) {
    const text = extractText(output.message.content);
    if (!text) {
      continue;
    }
    await (0, import_tracing.runInNewSpan)(
      { metadata: { name: "sanitizeModelResponse" } },
      async (meta) => {
        meta.input = {
          name: options.templateName,
          modelResponseData: { text }
        };
        const [apiResponse] = await client.sanitizeModelResponse({
          name: options.templateName,
          modelResponseData: { text }
        });
        meta.output = apiResponse;
        if (!apiResponse.sanitizationResult) {
          return;
        }
        const sanitization = apiResponse.sanitizationResult;
        const sdp = applySdp([output.message], 0, sanitization, options);
        const hasCustomApplier = typeof options.applyDeidentificationResults === "function";
        if (sdp.sdpApplied || hasCustomApplier) {
          output.message = sdp.messages[0];
        }
        if (shouldBlock(sanitization, options, sdp.sdpApplied)) {
          throw new import_genkit.GenkitError({
            status: "PERMISSION_DENIED",
            message: "Model Armor blocked model response.",
            detail: sanitization
          });
        }
      }
    );
  }

  // Write the (possibly rewritten) wrapped message back onto the response.
  if (hasSingleMessage && outputs.length > 0) {
    response.message = outputs[0].message;
  }
}

/**
 * Builds the Model Armor middleware.
 *
 * `options.protectionTarget` selects what is sanitized: `"userPrompt"`,
 * `"modelResponse"`, or `"all"` (the default when unset). A client is taken
 * from `options.client` or constructed from `options.clientOptions`.
 *
 * @param {object} options - Plugin options.
 * @returns {(req: object, next: Function) => Promise<object>} Async middleware
 *   that sanitizes the prompt before `next(req)` and the response after it.
 */
function modelArmor(options) {
  const client = options.client || new import_modelarmor.ModelArmorClient(options.clientOptions);
  const target = options.protectionTarget ?? "all";
  const guardsPrompt = target === "all" || target === "userPrompt";
  const guardsResponse = target === "all" || target === "modelResponse";
  return async (req, next) => {
    if (guardsPrompt) {
      await sanitizeUserPrompt(req, client, options);
    }
    const modelResponse = await next(req);
    if (guardsResponse) {
      await sanitizeModelResponse(modelResponse, client, options);
    }
    return modelResponse;
  };
}
// Annotate the CommonJS export names for ESM import in node:
0 && (module.exports = {
  modelArmor
});
//# sourceMappingURL=model-armor.js.map