UNPKG

@aikidosec/firewall

Version:

Zen by Aikido is an embedded Web Application Firewall that autonomously protects Node.js apps against common and critical attacks.

158 lines (157 loc) 6.25 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.AiSDK = void 0;
const wrapExport_1 = require("../agent/hooks/wrapExport");
const isPlainObject_1 = require("../helpers/isPlainObject");
// Prefix that some model ids carry in the response (e.g. "models/<name>").
const MODELS_PREFIX = "models/";
/**
 * Coerces a token count reported by the AI SDK into a value that is safe to
 * aggregate. Non-finite values (providers may report NaN when usage is not
 * available) and negative values become 0.
 *
 * @param {number} value - Token count taken from `response.usage`.
 * @returns {number} `value` when it is a finite, non-negative number; otherwise 0.
 */
function toTokenCount(value) {
    return Number.isFinite(value) && value >= 0 ? value : 0;
}
/**
 * Instruments the `ai` package (Vercel AI SDK, v4) so that calls to
 * generateText / generateObject / streamText / streamObject are reported to
 * the agent's AI statistics (provider, model and token usage).
 *
 * Inspection is best-effort: it never changes the value returned to the
 * application and never lets its own errors escape.
 */
class AiSDK {
    /**
     * Records a single AI call in the agent's AI statistics.
     *
     * Does nothing when `response` does not look like an AI SDK result or when
     * the provider cannot be derived from the call arguments.
     *
     * @param {object} agent - Agent exposing `getAIStatistics()`.
     * @param {unknown[]} args - Arguments the wrapped AI SDK function received.
     * @param {unknown} response - Resolved result (or `{ response, usage }` for streams).
     */
    inspectAiCall(agent, args, response) {
        if (!this.isResult(response)) {
            return;
        }
        const provider = this.getProviderFromArgs(args);
        if (!provider) {
            return;
        }
        const modelName = this.getModelName(response);
        const aiStats = agent.getAIStatistics();
        aiStats.onAICall({
            provider: provider,
            model: modelName,
            // Guard against NaN / invalid counts so they cannot corrupt totals.
            inputTokens: toTokenCount(response.usage.promptTokens),
            outputTokens: toTokenCount(response.usage.completionTokens),
        });
    }
    /**
     * Checks that `result` has the shape this sink relies on: a `usage` plain
     * object with numeric `promptTokens`, `completionTokens` and `totalTokens`,
     * and a `response` plain object with a string `modelId`.
     *
     * @param {unknown} result - Value to check.
     * @returns {boolean} True when the shape matches.
     */
    isResult(result) {
        if (result &&
            typeof result === "object" &&
            // The result itself is not a plain object (so isPlainObject is not
            // used here), but its `usage` and `response` members are.
            !Array.isArray(result) &&
            "usage" in result &&
            (0, isPlainObject_1.isPlainObject)(result.usage) &&
            typeof result.usage.completionTokens === "number" &&
            typeof result.usage.promptTokens === "number" &&
            typeof result.usage.totalTokens === "number" &&
            "response" in result &&
            result.response &&
            (0, isPlainObject_1.isPlainObject)(result.response) &&
            typeof result.response.modelId === "string") {
            return true;
        }
        return false;
    }
    /**
     * Derives a normalized provider name from `args[0].model.provider`.
     *
     * Normalization steps, in order:
     * - keep the part before the first "." ("google.generativeai" -> "google");
     * - "amazon-bedrock" becomes "bedrock";
     * - keep the part before the first "-" ("azure-openai" -> "azure");
     * - "google" becomes "gemini".
     *
     * @param {unknown} args - Arguments of the wrapped call.
     * @returns {string | undefined} Provider name, or undefined if it cannot be determined.
     */
    getProviderFromArgs(args) {
        if (!Array.isArray(args) || args.length === 0) {
            return undefined;
        }
        const firstArg = args[0];
        if (!(0, isPlainObject_1.isPlainObject)(firstArg)) {
            return undefined;
        }
        if (!firstArg.model || typeof firstArg.model !== "object") {
            return undefined;
        }
        if (!("provider" in firstArg.model) ||
            typeof firstArg.model.provider !== "string") {
            return undefined;
        }
        let providerName = firstArg.model.provider;
        if (providerName.includes(".")) {
            // e.g. google.generativeai
            providerName = providerName.split(".")[0];
        }
        if (providerName === "amazon-bedrock") {
            return "bedrock"; // Normalize amazon-bedrock to bedrock
        }
        if (providerName.includes("-")) {
            // e.g. azure-openai
            providerName = providerName.split("-")[0];
        }
        if (providerName === "google") {
            return "gemini"; // Normalize google to gemini
        }
        return providerName;
    }
    /**
     * Returns `response.response.modelId` without a leading "models/" prefix
     * (e.g. "models/foo" -> "foo").
     *
     * @param {object} response - A value for which `isResult` returned true.
     * @returns {string} The model name.
     */
    getModelName(response) {
        const modelName = response.response.modelId;
        return modelName.startsWith(MODELS_PREFIX)
            ? modelName.slice(MODELS_PREFIX.length)
            : modelName;
    }
    /**
     * Interceptors for the promise-returning functions (generateText,
     * generateObject). The returned promise is passed through unchanged; the
     * call is inspected once it fulfills.
     */
    getInterceptors() {
        return {
            kind: "ai_op",
            modifyReturnValue: (args, returnValue, agent) => {
                if (returnValue instanceof Promise) {
                    // Inspect the response after the promise resolves; this works on a
                    // derived promise and does not change the original promise.
                    returnValue.then((response) => {
                        try {
                            this.inspectAiCall(agent, args, response);
                        }
                        catch {
                            // If we don't catch these errors, it will result in an unhandled promise rejection!
                        }
                    }, () => {
                        // The original promise's rejection is left for the caller to handle.
                        // Without this handler the derived promise created by `.then()` would
                        // reject unhandled, which Node treats as a fatal error by default.
                    });
                }
                return returnValue;
            },
        };
    }
    /**
     * Interceptors for the streaming functions (streamText, streamObject).
     * The return value is passed through unchanged. When it exposes `response`
     * and `usage` promises, both are awaited via `Promise.allSettled`, and the
     * call is recorded only if both fulfilled with truthy values.
     */
    getStreamInterceptors() {
        return {
            kind: "ai_op",
            modifyReturnValue: (args, returnValue, agent) => {
                if (!returnValue ||
                    typeof returnValue !== "object" ||
                    !("response" in returnValue) ||
                    !(returnValue.response instanceof Promise) ||
                    !("usage" in returnValue) ||
                    !(returnValue.usage instanceof Promise)) {
                    return returnValue;
                }
                // allSettled never rejects, so no rejection handler is needed here.
                Promise.allSettled([returnValue.response, returnValue.usage]).then((promiseResults) => {
                    const response = promiseResults[0].status === "fulfilled"
                        ? promiseResults[0].value
                        : undefined;
                    const usage = promiseResults[1].status === "fulfilled"
                        ? promiseResults[1].value
                        : undefined;
                    if (!response || !usage) {
                        return;
                    }
                    try {
                        this.inspectAiCall(agent, args, {
                            response,
                            usage,
                        });
                    }
                    catch {
                        // If we don't catch these errors, it will result in an unhandled promise rejection!
                    }
                });
                return returnValue;
            },
        };
    }
    /**
     * Registers the hook for the `ai` package (versions matching ^4.0.0).
     * On require, returns a shallow copy of the package exports in which the
     * four generation functions are wrapped with the interceptors above.
     *
     * @param {object} hooks - Hook registry providing `addPackage(...)`.
     */
    wrap(hooks) {
        hooks
            .addPackage("ai")
            .withVersion("^4.0.0")
            .onRequire((exports, pkgInfo) => {
            // Can't wrap it directly because it's a readonly proxy
            const generateTextFunc = exports.generateText;
            const generateObjectFunc = exports.generateObject;
            const streamTextFunc = exports.streamText;
            const streamObjectFunc = exports.streamObject;
            const interceptors = this.getInterceptors();
            const streamInterceptors = this.getStreamInterceptors();
            return {
                ...exports,
                generateText: (0, wrapExport_1.wrapExport)(generateTextFunc, undefined, pkgInfo, interceptors),
                generateObject: (0, wrapExport_1.wrapExport)(generateObjectFunc, undefined, pkgInfo, interceptors),
                streamText: (0, wrapExport_1.wrapExport)(streamTextFunc, undefined, pkgInfo, streamInterceptors),
                streamObject: (0, wrapExport_1.wrapExport)(streamObjectFunc, undefined, pkgInfo, streamInterceptors),
            };
        });
    }
}
exports.AiSDK = AiSDK;