@aikidosec/firewall
Version:
Zen by Aikido is an embedded Web Application Firewall that autonomously protects Node.js apps against common and critical attacks
158 lines (157 loc) • 6.25 kB
JavaScript
;
Object.defineProperty(exports, "__esModule", { value: true });
exports.AiSDK = void 0;
const wrapExport_1 = require("../agent/hooks/wrapExport");
const isPlainObject_1 = require("../helpers/isPlainObject");
/**
 * Instruments the `ai` package (versions matching ^4.0.0) so that calls to
 * generateText / generateObject / streamText / streamObject report the
 * provider, model name and token usage to the agent's AI statistics.
 */
class AiSDK {
    /**
     * Records one AI call in the agent's AI statistics.
     *
     * Does nothing when `response` does not have the expected result shape
     * (see `isResult`) or when no provider name can be derived from `args`.
     *
     * @param {object} agent - Agent exposing `getAIStatistics()`.
     * @param {unknown[]} args - Arguments the wrapped SDK function was called with.
     * @param {unknown} response - Resolved result of the SDK call.
     */
    inspectAiCall(agent, args, response) {
        if (!this.isResult(response)) {
            return;
        }
        const provider = this.getProviderFromArgs(args);
        if (!provider) {
            // Also covers an empty provider string.
            return;
        }
        const modelName = this.getModelName(response);
        const aiStats = agent.getAIStatistics();
        // NOTE(review): isResult only checks `typeof === "number"`, which also
        // admits NaN token counts — confirm onAICall copes with non-finite values.
        aiStats.onAICall({
            provider,
            model: modelName,
            inputTokens: response.usage.promptTokens,
            outputTokens: response.usage.completionTokens,
        });
    }
    /**
     * Checks whether `result` has the shape this class reads: a `usage` plain
     * object with numeric `promptTokens`, `completionTokens` and `totalTokens`,
     * and a `response` plain object with a string `modelId`.
     *
     * @param {unknown} result
     * @returns {boolean}
     */
    isResult(result) {
        if (result &&
            // The result object itself is not expected to be a plain object,
            // so only a typeof check is applied to it (unlike usage/response).
            typeof result === "object" &&
            !Array.isArray(result) &&
            "usage" in result &&
            (0, isPlainObject_1.isPlainObject)(result.usage) &&
            typeof result.usage.completionTokens === "number" &&
            typeof result.usage.promptTokens === "number" &&
            typeof result.usage.totalTokens === "number" &&
            "response" in result &&
            result.response &&
            (0, isPlainObject_1.isPlainObject)(result.response) &&
            typeof result.response.modelId === "string") {
            return true;
        }
        return false;
    }
    /**
     * Derives a normalized provider name from `args[0].model.provider`.
     *
     * Normalization steps, in order:
     * - keep only the part before the first "." (e.g. "google.generativeai" -> "google");
     * - map "amazon-bedrock" to "bedrock";
     * - keep only the part before the first "-" (e.g. "azure-openai" -> "azure");
     * - map "google" to "gemini".
     *
     * @param {unknown} args - Arguments the wrapped SDK function was called with.
     * @returns {string | undefined} The provider name, or undefined if it cannot be read.
     */
    getProviderFromArgs(args) {
        if (!Array.isArray(args) || args.length === 0) {
            return undefined;
        }
        const firstArg = args[0];
        if (!(0, isPlainObject_1.isPlainObject)(firstArg)) {
            return undefined;
        }
        if (!firstArg.model || typeof firstArg.model !== "object") {
            return undefined;
        }
        if (!("provider" in firstArg.model) ||
            typeof firstArg.model.provider !== "string") {
            return undefined;
        }
        let providerName = firstArg.model.provider;
        if (providerName.includes(".")) {
            // e.g. google.generativeai
            providerName = providerName.split(".")[0];
        }
        if (providerName === "amazon-bedrock") {
            return "bedrock"; // Checked before the "-" split below, which would yield "amazon"
        }
        if (providerName.includes("-")) {
            // e.g. azure-openai
            providerName = providerName.split("-")[0];
        }
        if (providerName === "google") {
            return "gemini"; // Normalize google to gemini
        }
        return providerName;
    }
    /**
     * Returns `response.response.modelId` without a leading "models/" prefix.
     *
     * @param {{ response: { modelId: string } }} response - A value accepted by `isResult`.
     * @returns {string}
     */
    getModelName(response) {
        let modelName = response.response.modelId;
        if (modelName.startsWith("models/")) {
            modelName = modelName.slice("models/".length);
        }
        return modelName;
    }
    /**
     * Interceptors for the promise-returning functions (generateText, generateObject).
     *
     * The return value is passed through unchanged. When it is a Promise, its
     * resolved value is inspected as a side effect.
     *
     * @returns {{ kind: string, modifyReturnValue: Function }}
     */
    getInterceptors() {
        return {
            kind: "ai_op",
            modifyReturnValue: (args, returnValue, agent) => {
                if (returnValue instanceof Promise) {
                    // Inspect the response after the promise resolves; the original
                    // promise handed back to the caller is not changed.
                    returnValue.then((response) => {
                        try {
                            this.inspectAiCall(agent, args, response);
                        }
                        catch {
                            // If we don't catch these errors, it will result in an unhandled promise rejection!
                        }
                    }, () => {
                        // The caller observes the rejection through `returnValue` itself.
                        // Without this handler the promise created by `.then()` would also
                        // reject with nobody listening, raising an unhandled rejection.
                    });
                }
                return returnValue;
            },
        };
    }
    /**
     * Interceptors for the streaming functions (streamText, streamObject).
     *
     * The return value is passed through unchanged. When it exposes `response`
     * and `usage` as Promises, both are awaited (via allSettled, which never
     * rejects) and combined into a result that is then inspected.
     *
     * @returns {{ kind: string, modifyReturnValue: Function }}
     */
    getStreamInterceptors() {
        return {
            kind: "ai_op",
            modifyReturnValue: (args, returnValue, agent) => {
                if (!returnValue ||
                    typeof returnValue !== "object" ||
                    !("response" in returnValue) ||
                    !(returnValue.response instanceof Promise) ||
                    !("usage" in returnValue) ||
                    !(returnValue.usage instanceof Promise)) {
                    return returnValue;
                }
                Promise.allSettled([returnValue.response, returnValue.usage]).then(([responseResult, usageResult]) => {
                    const response = responseResult.status === "fulfilled"
                        ? responseResult.value
                        : undefined;
                    const usage = usageResult.status === "fulfilled"
                        ? usageResult.value
                        : undefined;
                    if (!response || !usage) {
                        return;
                    }
                    try {
                        this.inspectAiCall(agent, args, {
                            response,
                            usage,
                        });
                    }
                    catch {
                        // If we don't catch these errors, it will result in an unhandled promise rejection!
                    }
                });
                return returnValue;
            },
        };
    }
    /**
     * Registers a require hook for the "ai" package (^4.0.0).
     *
     * The hook returns a shallow copy of the package exports with the four
     * generate/stream functions replaced by wrapped versions.
     *
     * @param {object} hooks - Hook registry exposing `addPackage(...)`.
     */
    wrap(hooks) {
        hooks
            .addPackage("ai")
            .withVersion("^4.0.0")
            .onRequire((exports, pkgInfo) => {
            // Can't wrap it directly because it's a readonly proxy
            const generateTextFunc = exports.generateText;
            const generateObjectFunc = exports.generateObject;
            const streamTextFunc = exports.streamText;
            const streamObjectFunc = exports.streamObject;
            const interceptors = this.getInterceptors();
            const streamInterceptors = this.getStreamInterceptors();
            return {
                ...exports,
                generateText: (0, wrapExport_1.wrapExport)(generateTextFunc, undefined, pkgInfo, interceptors),
                generateObject: (0, wrapExport_1.wrapExport)(generateObjectFunc, undefined, pkgInfo, interceptors),
                streamText: (0, wrapExport_1.wrapExport)(streamTextFunc, undefined, pkgInfo, streamInterceptors),
                streamObject: (0, wrapExport_1.wrapExport)(streamObjectFunc, undefined, pkgInfo, streamInterceptors),
            };
        });
    }
}
exports.AiSDK = AiSDK;