@future-agi/ai-evaluation
Version:
We help GenAI teams maintain high-accuracy for their Models in production.
271 lines • 14.2 kB
JavaScript
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }
return new (P || (P = Promise))(function (resolve, reject) {
function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); }
step((generator = generator.apply(thisArg, _arguments || [])).next());
});
};
import { HttpMethod, Routes, InvalidAuthError, SDKException, InvalidValueType, MissingRequiredKey, } from '@future-agi/sdk';
import { EvalResponseHandler, Evaluator } from './evaluator';
import { Templates } from './templates';
const PROTECT_FLASH_ID = "76";
export class Protect {
constructor(options = {}) {
if (options.evaluator) {
this.evaluator = options.evaluator;
}
else {
const fiApiKey = process.env.FI_API_KEY || options.fiApiKey;
const fiSecretKey = process.env.FI_SECRET_KEY || options.fiSecretKey;
const fiBaseUrl = process.env.FI_BASE_URL || options.fiBaseUrl;
if (!fiApiKey || !fiSecretKey) {
throw new InvalidAuthError("API key or secret key is missing for Protect initialization.");
}
this.evaluator = new Evaluator({ fiApiKey, fiSecretKey, fiBaseUrl });
}
this.metric_map = {
"Toxicity": Templates.Toxicity,
"Tone": Templates.Tone,
"Sexism": Templates.Sexist,
"Prompt Injection": Templates.PromptInjection,
"Data Privacy": Templates.DataPrivacyCompliance,
};
}
_check_rule_sync(rule, testCase, timeoutSeconds) {
return __awaiter(this, void 0, void 0, function* () {
var _a, _b;
const templateInfo = this.metric_map[rule.metric];
const templateConfig = { call_type: "protect" };
if (rule.metric === "Data Privacy") {
templateConfig.check_internet = false;
}
const payload = {
inputs: [testCase],
config: {
[templateInfo.eval_id]: templateConfig
},
};
const timeoutMs = Math.max(0, timeoutSeconds * 1000);
const evalResult = yield this.evaluator.request({
method: HttpMethod.POST,
url: `${this.evaluator.baseUrl}/${Routes.evaluate}`,
json: payload,
timeout: timeoutMs
}, EvalResponseHandler);
let reasonText = undefined;
if (evalResult.eval_results && evalResult.eval_results[0]) {
const result = evalResult.eval_results[0];
const detectedValues = result.data || [];
let shouldTrigger = false;
if (rule.type === "any") {
shouldTrigger = detectedValues.some((value) => { var _a; return (_a = rule.contains) === null || _a === void 0 ? void 0 : _a.includes(value); });
}
else if (rule.type === "all") {
shouldTrigger = (_b = (_a = rule.contains) === null || _a === void 0 ? void 0 : _a.every((value) => detectedValues.includes(value))) !== null && _b !== void 0 ? _b : false;
}
if (shouldTrigger) {
const message = rule.action;
if (rule._internal_reason_flag) {
reasonText = result.reason;
}
return [rule.metric, true, message, reasonText];
}
}
return [rule.metric, false, undefined, undefined];
});
}
_process_rules_batch(rules, testCase, remainingTime // in seconds
) {
return __awaiter(this, void 0, void 0, function* () {
let failureMessages = [];
let completedRules = [];
let uncompletedRules = [];
let failureReasons = [];
let failedRule = undefined;
const timeoutPromise = new Promise((resolve) => setTimeout(() => resolve("timeout"), remainingTime * 1000));
const rulePromises = rules.map(rule => this._check_rule_sync(rule, testCase, remainingTime)
.then(value => ({ status: 'fulfilled', value, metric: rule.metric }))
.catch(reason => ({ status: 'rejected', reason, metric: rule.metric })));
const raceResult = yield Promise.race([Promise.all(rulePromises), timeoutPromise]);
if (raceResult === "timeout") {
uncompletedRules = rules.map(r => r.metric);
return [failureMessages, completedRules, uncompletedRules, failureReasons, failedRule];
}
for (const result of raceResult) {
if (result.status === 'fulfilled') {
const [metric, triggered, message, reason_text] = result.value;
completedRules.push(metric);
if (triggered && !failedRule) {
failedRule = metric;
failureMessages.push(message);
if (reason_text) {
failureReasons.push(reason_text);
}
}
}
else {
console.error(`Rule ${result.metric} failed with error:`, result.reason);
}
}
const allMetrics = rules.map(r => r.metric);
uncompletedRules = allMetrics.filter(m => !completedRules.includes(m));
return [failureMessages, completedRules, uncompletedRules, failureReasons, failedRule];
});
}
protect(inputs_1) {
return __awaiter(this, arguments, void 0, function* (inputs, protectRules = null, action = "Response cannot be generated as the input fails the checks", reason = false, timeout = 30000, // milliseconds
useFlash = false) {
var _a;
const timeoutSeconds = timeout / 1000.0;
let protectRulesCopy = protectRules ? JSON.parse(JSON.stringify(protectRules)) : [];
if (useFlash && protectRulesCopy.length === 0) {
protectRulesCopy = [{ metric: "Toxicity" }];
}
else if (useFlash) {
console.log("Note: When using ProtectFlash, Rules are not considered as it performs binary harmful/not harmful classification only.");
}
if (typeof inputs !== 'string') {
throw new InvalidValueType("inputs", inputs, "string");
}
const input_text = inputs;
if (!input_text.trim()) {
throw new InvalidValueType("inputs", input_text, "non-empty string or string with non-whitespace characters");
}
const inputsList = [input_text];
if (useFlash) {
const testCase = { input: inputsList[0], call_type: "protect" };
const templateInfo = this.metric_map[protectRulesCopy[0].metric];
const payload = {
inputs: [testCase],
config: { [PROTECT_FLASH_ID]: { call_type: "protect" } },
protect_flash: true
};
const response = yield this.evaluator.request({
method: HttpMethod.POST,
url: `${this.evaluator.baseUrl}/${Routes.evaluate}`,
json: payload,
timeout: timeoutSeconds * 1000
}, EvalResponseHandler);
if ((_a = response === null || response === void 0 ? void 0 : response.eval_results) === null || _a === void 0 ? void 0 : _a[0]) {
const result = response.eval_results[0];
const isHarmful = result.failure;
return {
status: isHarmful ? "failed" : "passed",
completed_rules: ["ProtectFlash"],
uncompleted_rules: [],
failed_rule: isHarmful ? "ProtectFlash" : null,
messages: isHarmful ? (protectRulesCopy[0].action || action) : inputsList[0],
reasons: isHarmful ? "Content detected as harmful." : "All checks passed",
time_taken: result.runtime ? result.runtime / 1000 : 0,
};
}
else {
return { status: "error", messages: "Evaluation failed", completed_rules: [], uncompleted_rules: ["ProtectFlash"], failed_rule: null, reasons: "No evaluation results returned", time_taken: 0 };
}
}
const testCases = inputsList.map(input_text => ({ input: input_text, call_type: "protect" }));
// Validate rules
if (protectRulesCopy.length === 0) {
throw new InvalidValueType("protect_rules", protectRulesCopy, "non-empty list");
}
const validMetrics = new Set(Object.keys(this.metric_map));
const validTypes = new Set(['any', 'all']);
for (let i = 0; i < protectRulesCopy.length; i++) {
const rule = protectRulesCopy[i];
if (typeof rule !== 'object' || rule === null) {
throw new InvalidValueType(`Rule at index ${i}`, rule, "dictionary");
}
if (!rule.metric) {
throw new MissingRequiredKey(`Rule at index ${i}`, 'metric');
}
if (!validMetrics.has(rule.metric)) {
throw new InvalidValueType(`metric in Rule at index ${i}`, rule.metric, `one of ${[...validMetrics]}`);
}
const isToneMetric = rule.metric === "Tone";
if (isToneMetric) {
if (!rule.contains) {
throw new MissingRequiredKey(`Rule for Tone metric at index ${i}`, "contains");
}
if (!Array.isArray(rule.contains) || rule.contains.length === 0) {
throw new InvalidValueType(`'contains' in Tone rule at index ${i}`, rule.contains, "non-empty list");
}
if (rule.type && !validTypes.has(rule.type)) {
throw new InvalidValueType(`'type' in Tone rule at index ${i}`, rule.type, `one of ${[...validTypes]}`);
}
if (!rule.type) {
rule.type = "any"; // Default
}
}
else {
if (rule.contains) {
throw new SDKException(`'contains' should not be specified for ${rule.metric} metric at index ${i}. Provide it only for 'Tone' metric.`);
}
if (rule.type) {
throw new SDKException(`'type' should not be specified for ${rule.metric} metric at index ${i}. Provide it only for 'Tone' metric.`);
}
rule.contains = ["Failed"];
rule.type = "any";
}
rule._internal_reason_flag = reason;
if (!rule.action)
rule.action = action;
}
const startTime = Date.now();
let allFailureMessages = [];
let allCompletedRules = [];
let allUncompletedRules = [];
let allFailureReasons = [];
let failedRule = undefined;
for (const testCase of testCases) {
const BATCH_SIZE = 5;
for (let i = 0; i < protectRulesCopy.length; i += BATCH_SIZE) {
const elapsedSeconds = (Date.now() - startTime) / 1000;
const remainingTime = Math.max(0, timeoutSeconds - elapsedSeconds);
if (remainingTime <= 0) {
allUncompletedRules.push(...protectRulesCopy.slice(i).map((r) => r.metric));
break;
}
const rulesBatch = protectRulesCopy.slice(i, i + BATCH_SIZE);
const [messages, completed, uncompleted, fReasons, fRule] = yield this._process_rules_batch(rulesBatch, testCase, remainingTime);
allCompletedRules.push(...completed);
allUncompletedRules.push(...uncompleted);
if (fReasons)
allFailureReasons.push(...fReasons);
if (fRule) {
failedRule = fRule;
allFailureMessages = messages;
break;
}
}
if (failedRule)
break;
}
const finalProcessingDurationSeconds = (Date.now() - startTime) / 1000;
const status = failedRule ? "failed" : "passed";
const baseResult = {
status,
completed_rules: [...new Set(allCompletedRules)],
uncompleted_rules: [...new Set(allUncompletedRules)],
failed_rule: failedRule || null,
messages: status === 'failed' ? allFailureMessages[0] : inputsList[0],
time_taken: finalProcessingDurationSeconds,
};
if (reason) {
baseResult.reasons = status === 'failed'
? (allFailureReasons[0] || "A protection rule was triggered.")
: "All checks passed";
}
return baseResult;
});
}
}
/**
* Convenience function to evaluate input strings against protection rules.
*/
export const protect = (inputs, protectRules, action, reason, timeout, useFlash) => {
const protectClient = new Protect();
return protectClient.protect(inputs, protectRules, action, reason, timeout, useFlash);
};
//# sourceMappingURL=protect.js.map