@genkit-ai/google-cloud
Version:
Genkit AI framework plugin for Google Cloud Platform including Firestore trace/state store and deployment helpers for Cloud Functions for Firebase.
205 lines • 7.3 kB
JavaScript
"use strict";
// esbuild-generated CommonJS interop helpers. Together they build
// `module.exports` as an object of getters so the ES-module-style exports of
// this file are visible to `require()` callers.
var __defProp = Object.defineProperty;
var __getOwnPropDesc = Object.getOwnPropertyDescriptor;
var __getOwnPropNames = Object.getOwnPropertyNames;
var __hasOwnProp = Object.prototype.hasOwnProperty;
// Defines one enumerable getter on `target` per key of `all`. Each value in
// `all` is a thunk, so the export is read from its local binding on access.
var __export = (target, all) => {
for (var name in all)
__defProp(target, name, { get: all[name], enumerable: true });
};
// Copies every own property name of `from` (when it is an object or function)
// onto `to` as a getter that reads `from[key]`, skipping keys `to` already owns
// and the `except` key. Each copied getter is enumerable unless the source
// property exists and is non-enumerable. `desc` is only a scratch variable.
var __copyProps = (to, from, except, desc) => {
if (from && typeof from === "object" || typeof from === "function") {
for (let key of __getOwnPropNames(from))
if (!__hasOwnProp.call(to, key) && key !== except)
__defProp(to, key, { get: () => from[key], enumerable: !(desc = __getOwnPropDesc(from, key)) || desc.enumerable });
}
return to;
};
// Wraps an export object for CommonJS consumers: a fresh object carrying a
// non-enumerable, read-only `__esModule: true` marker plus getters for `mod`'s
// own properties.
var __toCommonJS = (mod) => __copyProps(__defProp({}, "__esModule", { value: true }), mod);
// Export table for this module; `modelArmor` is the only name exported here.
var model_armor_exports = {};
__export(model_armor_exports, {
modelArmor: () => modelArmor
});
module.exports = __toCommonJS(model_armor_exports);
var import_modelarmor = require("@google-cloud/modelarmor");
var import_genkit = require("genkit");
var import_tracing = require("genkit/tracing");
/**
 * Concatenates the `text` of every part into one string.
 *
 * Parts without a truthy `text` (e.g. media parts) contribute nothing.
 *
 * @param {Array<{text?: string}>} parts Message content parts.
 * @returns {string} The joined text, or "" when no part has text.
 */
function extractText(parts) {
  return parts.reduce((joined, part) => joined + (part.text || ""), "");
}
/**
 * Applies Sensitive Data Protection (SDP) de-identification from a Model Armor
 * sanitization result to a list of messages.
 *
 * What happens depends on `options.applyDeidentificationResults`:
 *  - a function: it is called (as a method of `options`) with
 *    `{ messages, sdpResult }`. A truthy return value becomes the new message
 *    list, and `sdpApplied` reports whether de-identified text was present.
 *  - `true`: when de-identified text exists, the text parts of
 *    `messages[targetIndex]` are dropped and a single text part holding the
 *    de-identified text is appended after the message's non-text parts.
 *  - anything else: nothing is changed.
 *
 * This helper does not mutate `messages`. The `true` branch returns shallow
 * copies of the array and of the target message.
 *
 * @param {Array<object>} messages Messages that may be rewritten.
 * @param {number} targetIndex Index of the message whose text was screened.
 * @param {object} result Model Armor `sanitizationResult`.
 * @param {object} options Middleware options.
 * @returns {{sdpApplied: boolean, messages: Array<object>}}
 */
function applySdp(messages, targetIndex, result, options) {
  const untouched = () => ({ sdpApplied: false, messages });

  const sdpResult = result.filterResults?.sdp?.sdpFilterResult;
  if (!sdpResult) {
    return untouched();
  }

  if (typeof options.applyDeidentificationResults === "function") {
    const replaced = options.applyDeidentificationResults({ messages, sdpResult });
    if (!replaced) {
      return untouched();
    }
    return {
      sdpApplied: Boolean(sdpResult.deidentifyResult?.data?.text),
      messages: replaced
    };
  }

  const redactedText = options.applyDeidentificationResults === true ? sdpResult.deidentifyResult?.data?.text : undefined;
  if (!redactedText) {
    return untouched();
  }

  // Keep non-text parts (e.g. media) and replace all text with the redacted version.
  const original = messages[targetIndex];
  const content = original.content.filter((part) => !part.text);
  content.push({ text: redactedText });

  const rewritten = Array.from(messages);
  rewritten[targetIndex] = { ...original, content };
  return { sdpApplied: true, messages: rewritten };
}
/**
 * Reports whether one entry of `sanitizationResult.filterResults` contains a
 * nested result whose `matchState` is "MATCH_FOUND".
 *
 * Most filters carry `matchState` directly on their nested result, e.g.
 * `{ raiFilterResult: { matchState } }`. The SDP filter nests it one level
 * deeper: `{ sdpFilterResult: { inspectResult | deidentifyResult: { matchState } } }`.
 * Both shapes are checked. Non-object values (such as a oneof discriminator
 * string) are skipped instead of being mistaken for the nested result.
 *
 * @param {object|undefined} filterResult One value from `filterResults`.
 * @returns {boolean} True when a match was reported.
 */
function filterResultMatched(filterResult) {
  if (!filterResult || typeof filterResult !== "object") {
    return false;
  }
  return Object.values(filterResult).some((nested) => {
    if (!nested || typeof nested !== "object") {
      return false;
    }
    if (nested.matchState === "MATCH_FOUND") {
      return true;
    }
    return nested.inspectResult?.matchState === "MATCH_FOUND" || nested.deidentifyResult?.matchState === "MATCH_FOUND";
  });
}

/**
 * Decides whether a Model Armor sanitization result should block the request.
 *
 * - Nothing is blocked unless the overall `filterMatchState` is "MATCH_FOUND".
 * - With `options.strictSdpEnforcement`, a result whose SDP de-identification
 *   was applied is always blocked.
 * - Otherwise each filter entry is inspected, skipping keys not listed in
 *   `options.filters` (when given) and skipping "sdp" when de-identification
 *   was applied. Any remaining entry reporting "MATCH_FOUND" blocks.
 *
 * @param {object} result Model Armor `sanitizationResult`.
 * @param {object} options Middleware options (`filters`, `strictSdpEnforcement`).
 * @param {boolean} sdpApplied Whether SDP de-identification was applied.
 * @returns {boolean} True when the request/response must be rejected.
 */
function shouldBlock(result, options, sdpApplied) {
  if (result.filterMatchState !== "MATCH_FOUND") {
    return false;
  }
  if (options.strictSdpEnforcement && sdpApplied) {
    return true;
  }
  if (!result.filterResults) {
    return false;
  }
  for (const [key, filterResult] of Object.entries(result.filterResults)) {
    if (options.filters && !options.filters.includes(key)) continue;
    // A de-identified SDP match is considered handled, not a violation.
    if (key === "sdp" && sdpApplied) continue;
    if (filterResultMatched(filterResult)) {
      return true;
    }
  }
  return false;
}
/**
 * Screens the latest user turn of a model request with Model Armor.
 *
 * Scans `req.messages` from the end for the last message whose role is
 * "user" and joins its text parts. If that text is non-empty, it is sent to
 * `client.sanitizeUserPrompt` inside a "sanitizeUserPrompt" tracing span, with
 * the request and response recorded as the span's input and output. SDP
 * de-identification is applied via applySdp. `req.messages` is replaced when
 * it took effect or when a custom de-identification handler is configured.
 *
 * @param {object} req Model request; `req.messages` may be reassigned.
 * @param {object} client Model Armor client exposing `sanitizeUserPrompt`.
 * @param {object} options Middleware options (`templateName`, SDP and filter settings).
 * @returns {Promise<void>}
 * @throws {GenkitError} PERMISSION_DENIED when shouldBlock reports a blocking match.
 */
async function sanitizeUserPrompt(req, client, options) {
  let userIndex = req.messages.length - 1;
  while (userIndex >= 0 && req.messages[userIndex].role !== "user") {
    userIndex--;
  }
  if (userIndex < 0) {
    return;
  }

  const promptText = extractText(req.messages[userIndex].content);
  if (!promptText) {
    // Nothing textual to screen (e.g. media-only turn).
    return;
  }

  const buildRequest = () => ({
    name: options.templateName,
    userPromptData: { text: promptText }
  });

  await (0, import_tracing.runInNewSpan)(
    { metadata: { name: "sanitizeUserPrompt" } },
    async (meta) => {
      meta.input = buildRequest();
      const [apiResponse] = await client.sanitizeUserPrompt(buildRequest());
      meta.output = apiResponse;

      const verdict = apiResponse.sanitizationResult;
      if (!verdict) {
        return;
      }

      const sdp = applySdp(req.messages, userIndex, verdict, options);
      if (sdp.sdpApplied || typeof options.applyDeidentificationResults === "function") {
        req.messages = sdp.messages;
      }

      if (shouldBlock(verdict, options, sdp.sdpApplied)) {
        throw new import_genkit.GenkitError({
          status: "PERMISSION_DENIED",
          message: "Model Armor blocked user prompt.",
          detail: verdict
        });
      }
    }
  );
}
/**
 * Screens the model's output with Model Armor, updating the response in place.
 *
 * Two response shapes are handled. When `response.message` is set, it is
 * screened on its own and written back afterwards. Otherwise each entry of
 * `response.candidates` (if any) is screened in order. For every message with
 * non-empty text, `client.sanitizeModelResponse` is called inside a
 * "sanitizeModelResponse" tracing span, SDP de-identification is applied via
 * applySdp, and the candidate's message is replaced when it took effect or
 * when a custom de-identification handler is configured.
 *
 * @param {object} response Model response; its message(s) may be replaced.
 * @param {object} client Model Armor client exposing `sanitizeModelResponse`.
 * @param {object} options Middleware options (`templateName`, SDP and filter settings).
 * @returns {Promise<void>}
 * @throws {GenkitError} PERMISSION_DENIED when shouldBlock reports a blocking match.
 */
async function sanitizeModelResponse(response, client, options) {
  const hasSingleMessage = !!response.message;
  let candidates;
  if (hasSingleMessage) {
    // Wrap the lone message so both shapes share the loop below.
    candidates = [{ index: 0, message: response.message, finishReason: "stop" }];
  } else {
    candidates = response.candidates || [];
  }

  for (const candidate of candidates) {
    const modelText = extractText(candidate.message.content);
    if (!modelText) {
      continue;
    }

    const buildRequest = () => ({
      name: options.templateName,
      modelResponseData: { text: modelText }
    });

    await (0, import_tracing.runInNewSpan)(
      { metadata: { name: "sanitizeModelResponse" } },
      async (meta) => {
        meta.input = buildRequest();
        const [apiResponse] = await client.sanitizeModelResponse(buildRequest());
        meta.output = apiResponse;

        const verdict = apiResponse.sanitizationResult;
        if (!verdict) {
          return;
        }

        const sdp = applySdp([candidate.message], 0, verdict, options);
        if (sdp.sdpApplied || typeof options.applyDeidentificationResults === "function") {
          candidate.message = sdp.messages[0];
        }

        if (shouldBlock(verdict, options, sdp.sdpApplied)) {
          throw new import_genkit.GenkitError({
            status: "PERMISSION_DENIED",
            message: "Model Armor blocked model response.",
            detail: verdict
          });
        }
      }
    );
  }

  // Propagate a possibly de-identified message back to the single-message shape.
  if (hasSingleMessage && candidates.length > 0) {
    response.message = candidates[0].message;
  }
}
/**
 * Creates a Genkit model middleware that screens traffic with Model Armor.
 *
 * Uses `options.client` when provided; otherwise constructs a
 * `ModelArmorClient` from `options.clientOptions`. `options.protectionTarget`
 * selects what is screened: "all" (the default when unset), "userPrompt", or
 * "modelResponse". Any other value screens nothing.
 *
 * @param {object} options Middleware options (`client`, `clientOptions`,
 *   `protectionTarget`, `templateName`, `applyDeidentificationResults`,
 *   `strictSdpEnforcement`, `filters`).
 * @returns {(req: object, next: Function) => Promise<object>} Middleware that
 *   sanitizes the prompt before `next(req)` and the response after it.
 */
function modelArmor(options) {
  const client = options.client || new import_modelarmor.ModelArmorClient(options.clientOptions);

  let screenPrompt = false;
  let screenResponse = false;
  switch (options.protectionTarget ?? "all") {
    case "all":
      screenPrompt = true;
      screenResponse = true;
      break;
    case "userPrompt":
      screenPrompt = true;
      break;
    case "modelResponse":
      screenResponse = true;
      break;
    default:
      break;
  }

  return async (req, next) => {
    if (screenPrompt) {
      await sanitizeUserPrompt(req, client, options);
    }
    const response = await next(req);
    if (screenResponse) {
      await sanitizeModelResponse(response, client, options);
    }
    return response;
  };
}
// Annotate the CommonJS export names for ESM import in node:
// `0 && (...)` short-circuits, so this assignment never runs. esbuild emits it
// only so Node's static analysis of CommonJS modules can discover the
// `modelArmor` export name when this file is loaded through `import`.
0 && (module.exports = {
modelArmor
});
//# sourceMappingURL=model-armor.js.map