@kya-os/agentshield-nextjs
Version:
Next.js middleware for AgentShield AI agent detection
381 lines (376 loc) • 12.6 kB
JavaScript
;
var server = require('next/server');
// src/api-middleware.ts
// src/api-client.ts
// Origin of the AgentShield REST API. Used for enforce calls when edge detection is
// disabled, and as the log-detection target when edge detection is enabled.
const DEFAULT_BASE_URL = "https://kya.vouched.id";
// Origin of the edge detection gateway; the default base URL when edge detection is on.
const EDGE_DETECT_URL = "https://detect.checkpoint-gateway.ai";
// Per-request timeout in milliseconds used when no positive timeout is configured.
const DEFAULT_TIMEOUT = 5000;
var AgentShieldClient = class {
  apiKey;
  baseUrl;
  useEdge;
  timeout;
  debug;
  /**
   * Create a client for the AgentShield enforce and log-detection endpoints.
   *
   * @param {object} config
   * @param {string} config.apiKey - Sent as a Bearer token on every request. Required.
   * @param {string} [config.baseUrl] - Service origin. Defaults to EDGE_DETECT_URL when edge
   *   detection is enabled, otherwise DEFAULT_BASE_URL.
   * @param {boolean} [config.useEdge] - Edge detection is on unless this is exactly `false`.
   * @param {number} [config.timeout] - Per-request timeout in ms; falsy values (including 0)
   *   fall back to DEFAULT_TIMEOUT.
   * @param {boolean} [config.debug] - When truthy, diagnostics are written to the console.
   * @throws {Error} If `config.apiKey` is missing or empty.
   */
  constructor(config) {
    if (!config.apiKey) {
      throw new Error("AgentShield API key is required");
    }
    this.apiKey = config.apiKey;
    this.useEdge = config.useEdge !== false;
    this.baseUrl = config.baseUrl || (this.useEdge ? EDGE_DETECT_URL : DEFAULT_BASE_URL);
    this.timeout = config.timeout || DEFAULT_TIMEOUT;
    this.debug = config.debug || false;
  }
  /**
   * Call the enforce API to decide whether a request should be allowed.
   *
   * Posts to `${baseUrl}/__detect/enforce` in edge mode and `${baseUrl}/api/v1/enforce`
   * otherwise. The abort timer covers both the fetch and reading the response body.
   *
   * This method does not reject. Failures are returned as `{ success: false, error }`, where
   * `error.code` is one of:
   * - `HTTP_<status>` for a non-2xx response, whether or not its body is JSON;
   * - `TIMEOUT` when the request was aborted after `timeout` ms;
   * - `INVALID_RESPONSE` for a 2xx response whose JSON body is not an object;
   * - `NETWORK_ERROR` for anything else, including an unparseable 2xx body.
   *
   * @param {object} input - Serialized as the JSON request body. `input.requestId`, if set,
   *   is also sent as `X-Request-ID`; otherwise a random UUID is sent.
   * @returns {Promise<object>} The parsed API response on 2xx, otherwise an error result.
   */
  async enforce(input) {
    const startTime = Date.now();
    const controller = new AbortController();
    // Cleared in `finally`, i.e. only after the body has been read, so a server that sends
    // headers and then stalls the body is still cut off after `this.timeout` ms.
    const timeoutId = setTimeout(() => controller.abort(), this.timeout);
    try {
      const endpoint = this.useEdge ? `${this.baseUrl}/__detect/enforce` : `${this.baseUrl}/api/v1/enforce`;
      const response = await fetch(endpoint, {
        method: "POST",
        headers: {
          "Content-Type": "application/json",
          Authorization: `Bearer ${this.apiKey}`,
          "X-Request-ID": input.requestId || crypto.randomUUID()
        },
        body: JSON.stringify(input),
        signal: controller.signal
      });
      let data;
      try {
        data = await response.json();
      } catch (parseError) {
        // An unreadable body on a 2xx status is a real failure (reported by the outer catch).
        // On an error status (e.g. an HTML 502 page from a proxy), continue without a body so
        // the caller still gets an accurate HTTP_<status> error.
        if (response.ok) {
          throw parseError;
        }
        data = null;
      }
      if (this.debug) {
        console.log("[AgentShield] Enforce response:", {
          status: response.status,
          action: data?.data?.decision?.action,
          processingTimeMs: Date.now() - startTime
        });
      }
      if (!response.ok) {
        return {
          success: false,
          error: {
            code: `HTTP_${response.status}`,
            message: data?.error?.message || `HTTP error: ${response.status}`
          }
        };
      }
      if (data === null || typeof data !== "object") {
        return {
          success: false,
          error: {
            code: "INVALID_RESPONSE",
            message: "Enforce response body was not a JSON object"
          }
        };
      }
      return data;
    } catch (error) {
      if (error instanceof Error && error.name === "AbortError") {
        if (this.debug) {
          console.warn("[AgentShield] Request timed out");
        }
        return {
          success: false,
          error: {
            code: "TIMEOUT",
            message: `Request timed out after ${this.timeout}ms`
          }
        };
      }
      if (this.debug) {
        console.error("[AgentShield] Request failed:", error);
      }
      return {
        success: false,
        error: {
          code: "NETWORK_ERROR",
          message: error instanceof Error ? error.message : "Network request failed"
        }
      };
    } finally {
      clearTimeout(timeoutId);
    }
  }
  /**
   * Quick check that reduces the enforce result to just the action.
   *
   * Fails open: when `enforce` reports an error or returns no data, the action is "allow"
   * and the error message (if any) is included as `error`.
   *
   * @param {object} input - Same shape as for `enforce`.
   * @returns {Promise<{action: string, error?: string}>}
   */
  async quickCheck(input) {
    const result = await this.enforce(input);
    if (!result.success || !result.data) {
      return {
        action: "allow",
        error: result.error?.message
      };
    }
    return {
      action: result.data.decision.action
    };
  }
  /**
   * Whether this client sends enforce calls to the edge detection gateway.
   *
   * @returns {boolean}
   */
  isUsingEdge() {
    return this.useEdge;
  }
  /**
   * Send a detection result to AgentShield's log-detection endpoint.
   *
   * Intended for use after Gateway (edge) detection, to persist results. In edge mode the
   * request goes to DEFAULT_BASE_URL rather than `baseUrl`.
   *
   * The returned promise settles only after the HTTP response arrives (or the timeout fires).
   * To keep it off the critical path, do not await it, and always attach a `.catch`:
   * - it rejects on network errors and timeouts;
   * - non-2xx responses do not reject and are only reported in debug mode.
   *
   * Only the listed `detection` fields are forwarded. `source` defaults to "gateway".
   *
   * @example
   * ```typescript
   * // After receiving Gateway response
   * if (client.isUsingEdge() && response.data?.detection) {
   *   client.logDetection({
   *     detection: response.data.detection,
   *     context: { userAgent, ipAddress, path, url, method }
   *   }).catch(err => console.error('Log failed:', err));
   * }
   * ```
   *
   * @param {object} input - `{ detection, context, source? }`.
   * @returns {Promise<void>}
   * @throws Rejects with the underlying fetch/abort error on failure.
   */
  async logDetection(input) {
    const logEndpoint = this.useEdge ? `${DEFAULT_BASE_URL}/api/v1/log-detection` : `${this.baseUrl}/api/v1/log-detection`;
    const controller = new AbortController();
    const timeoutId = setTimeout(() => controller.abort(), this.timeout);
    try {
      const response = await fetch(logEndpoint, {
        method: "POST",
        headers: {
          "Content-Type": "application/json",
          Authorization: `Bearer ${this.apiKey}`
        },
        body: JSON.stringify({
          detection: {
            isAgent: input.detection.isAgent,
            confidence: input.detection.confidence,
            agentName: input.detection.agentName,
            agentType: input.detection.agentType,
            detectionClass: input.detection.detectionClass,
            verificationMethod: input.detection.verificationMethod,
            reasons: input.detection.reasons
          },
          context: input.context,
          source: input.source || "gateway"
        }),
        signal: controller.signal
      });
      if (!response.ok && this.debug) {
        console.warn("[AgentShield] Log detection returned non-2xx:", response.status);
      }
    } catch (error) {
      if (this.debug) {
        console.error("[AgentShield] Log detection failed:", error);
      }
      throw error;
    } finally {
      clearTimeout(timeoutId);
    }
  }
};
/**
 * Clients created by `getAgentShieldClient`, keyed by their fully resolved settings.
 * Grows by one entry per distinct configuration, which is expected to be a small,
 * static set per application.
 */
var clientInstances = new Map();
/**
 * Return a shared AgentShieldClient for the given configuration.
 *
 * Each option falls back to an environment variable:
 * - `apiKey` falls back to AGENTSHIELD_API_KEY;
 * - `baseUrl` falls back to AGENTSHIELD_API_URL;
 * - `useEdge` falls back to AGENTSHIELD_USE_EDGE, and edge is on unless that is "false";
 * - `debug` falls back to AGENTSHIELD_DEBUG === "true".
 *
 * Calls that resolve to identical settings share one instance. Different settings get
 * their own client instead of silently reusing whichever client was created first.
 *
 * @param {object} [config] - `{ apiKey, baseUrl, useEdge, timeout, debug }`, all optional.
 * @returns {AgentShieldClient}
 * @throws {Error} If no API key is supplied and AGENTSHIELD_API_KEY is unset.
 */
function getAgentShieldClient(config) {
  const apiKey = config?.apiKey || process.env.AGENTSHIELD_API_KEY;
  if (!apiKey) {
    throw new Error(
      "AgentShield API key is required. Set AGENTSHIELD_API_KEY environment variable or pass apiKey in config."
    );
  }
  const settings = {
    apiKey,
    baseUrl: config?.baseUrl || process.env.AGENTSHIELD_API_URL,
    // Default to edge detection unless explicitly disabled
    useEdge: config?.useEdge ?? process.env.AGENTSHIELD_USE_EDGE !== "false",
    timeout: config?.timeout,
    debug: config?.debug || process.env.AGENTSHIELD_DEBUG === "true"
  };
  const cacheKey = JSON.stringify([
    settings.apiKey,
    settings.baseUrl ?? null,
    settings.useEdge,
    settings.timeout ?? null,
    settings.debug
  ]);
  let client = clientInstances.get(cacheKey);
  if (!client) {
    client = new AgentShieldClient(settings);
    clientInstances.set(cacheKey, client);
  }
  return client;
}
// src/api-middleware.ts
/**
 * Test a request path against a single skip/include pattern.
 *
 * Pattern forms:
 * - exact string: matches only that path;
 * - containing "*": a glob anchored at both ends, where each "*" matches any run of
 *   characters (including "/") and every other character is literal;
 * - ending in "/": matches anything under that prefix, plus the same path without the
 *   trailing slash;
 * - anything else: plain string-prefix match (so "/api" also matches "/apiary").
 *
 * @param {string} path - Request pathname.
 * @param {string} pattern
 * @returns {boolean}
 */
function matchPath(path, pattern) {
  if (path === pattern) {
    return true;
  }
  if (pattern.includes("*")) {
    // Escape regex metacharacters (but not "*"), then widen each "*" into ".*".
    const literal = pattern.replace(/[.+?^${}()|[\]\\]/g, "\\$&");
    const source = literal.replace(/\*/g, ".*");
    return new RegExp(`^${source}$`).test(path);
  }
  const isDirectoryPattern = pattern.endsWith("/");
  if (isDirectoryPattern && path === pattern.slice(0, -1)) {
    return true;
  }
  return path.startsWith(pattern);
}
/**
 * Whether `path` matches any pattern in `skipPaths`; patterns are checked in order and the
 * scan stops at the first match.
 *
 * @param {string} path - Request pathname.
 * @param {string[]} skipPaths - Patterns understood by `matchPath`.
 * @returns {boolean}
 */
function shouldSkipPath(path, skipPaths) {
  const matchesPath = (pattern) => matchPath(path, pattern);
  return skipPaths.some(matchesPath);
}
/**
 * Whether `path` is in scope for detection.
 *
 * An absent or empty `includePaths` puts every path in scope; otherwise the path must match
 * at least one pattern.
 *
 * @param {string} path - Request pathname.
 * @param {string[]} [includePaths] - Patterns understood by `matchPath`.
 * @returns {boolean}
 */
function shouldIncludePath(path, includePaths) {
  const unrestricted = !includePaths || includePaths.length === 0;
  return unrestricted || includePaths.some((pattern) => matchPath(path, pattern));
}
/**
 * Coerce an API-supplied value into a string that `Headers#set` accepts.
 *
 * CR, LF and NUL become spaces, and characters above U+00FF become "?". Without this,
 * `Headers#set` throws TypeError, and the middleware's catch would then fail open on a
 * block decision.
 *
 * @param {*} value
 * @returns {string}
 */
function toHeaderValue(value) {
  return String(value).replace(/[\r\n\u0000]/g, " ").replace(/[^\u0000-\u00ff]/g, "?");
}
/**
 * Build the JSON error response returned when an agent is blocked.
 *
 * The status defaults to 403. The message comes from `config.blockedResponse.message`, then
 * `decision.message`, then "Access denied". Custom `config.blockedResponse.headers` are
 * applied first. `X-AgentShield-Action` and `X-AgentShield-Reason` are set afterwards, so
 * they win over custom headers with the same names; the reason header is omitted when the
 * decision has no reason.
 *
 * @param {object} decision - Enforce decision (`action`, `reason`, `message`, `agentType`).
 * @param {object} config - Middleware config; only `blockedResponse` is read.
 * @returns {NextResponse}
 */
function buildBlockedResponse(decision, config) {
  const status = config.blockedResponse?.status ?? 403;
  const message = config.blockedResponse?.message ?? decision.message ?? "Access denied";
  const response = server.NextResponse.json(
    {
      error: message,
      code: "AGENT_BLOCKED",
      reason: decision.reason,
      agentType: decision.agentType
    },
    { status }
  );
  const customHeaders = config.blockedResponse?.headers;
  if (customHeaders) {
    for (const [key, value] of Object.entries(customHeaders)) {
      response.headers.set(key, value);
    }
  }
  response.headers.set("X-AgentShield-Action", toHeaderValue(decision.action));
  if (decision.reason != null) {
    response.headers.set("X-AgentShield-Reason", toHeaderValue(decision.reason));
  }
  return response;
}
/**
 * Redirect a blocked or challenged agent to an interstitial page.
 *
 * The target is `config.redirectUrl`, then `decision.redirectUrl`, then "/blocked", resolved
 * against the request URL. `decision.reason` and `decision.agentType`, when present, are
 * added as the `reason` and `agent` query parameters.
 *
 * If the request is already for the target page (same origin and pathname), it is passed
 * through with `NextResponse.next()`. Otherwise an agent that is still blocked on that page
 * would be redirected to it again, producing a redirect loop.
 *
 * @param {Request} request - Incoming request; only `request.url` is read.
 * @param {object} decision - Enforce decision (`reason`, `agentType`, `redirectUrl`).
 * @param {object} config - Middleware config; only `redirectUrl` is read.
 * @returns {NextResponse}
 */
function buildRedirectResponse(request, decision, config) {
  const redirectUrl = config.redirectUrl || decision.redirectUrl || "/blocked";
  const currentUrl = new URL(request.url);
  const url = new URL(redirectUrl, currentUrl);
  if (url.origin === currentUrl.origin && url.pathname === currentUrl.pathname) {
    return server.NextResponse.next();
  }
  if (decision.reason != null) {
    url.searchParams.set("reason", decision.reason);
  }
  if (decision.agentType) {
    url.searchParams.set("agent", decision.agentType);
  }
  return server.NextResponse.redirect(url);
}
/**
 * Create a Next.js middleware that asks AgentShield whether each request may proceed.
 *
 * Built-in static paths (Next static/image assets, favicon, robots.txt, sitemap.xml) plus
 * `config.skipPaths` are never checked. If `config.includePaths` is non-empty, only matching
 * paths are checked.
 *
 * Decisions map to responses as follows:
 * - "block": `config.customBlockedResponse` if provided; otherwise a redirect when
 *   `config.onBlock === "redirect"`; otherwise a JSON error response.
 * - "redirect" and "challenge": a redirect.
 * - anything else: pass-through, annotated with X-AgentShield-* headers when an agent was
 *   detected.
 *
 * When the API call or the middleware itself fails, the request passes through if
 * `config.failOpen` (default true); otherwise a 503 JSON response is returned.
 *
 * @param {object} [config] - Middleware options (apiKey, apiUrl, useEdge, timeout, debug,
 *   skipPaths, includePaths, failOpen, onAgentDetected, onBlock, redirectUrl,
 *   blockedResponse, customBlockedResponse).
 * @returns {(request: NextRequest, event?: NextFetchEvent) => Promise<Response>}
 */
function withAgentShield(config = {}) {
  // Created lazily so a missing API key only surfaces when a request is actually checked.
  let client = null;
  const getClient = () => {
    if (!client) {
      client = getAgentShieldClient({
        apiKey: config.apiKey,
        baseUrl: config.apiUrl,
        useEdge: config.useEdge,
        timeout: config.timeout,
        debug: config.debug
      });
    }
    return client;
  };
  const defaultSkipPaths = [
    "/_next/static/*",
    "/_next/image/*",
    "/favicon.ico",
    "/robots.txt",
    "/sitemap.xml"
  ];
  const skipPaths = [...defaultSkipPaths, ...(config.skipPaths || [])];
  const failOpen = config.failOpen ?? true;
  // API-supplied values (e.g. agent names) may contain characters that Headers#set rejects.
  // CR/LF/NUL become spaces and non-Latin-1 characters become "?" so that setting an
  // informational header can never turn into a thrown error.
  const setHeader = (headers, name, value) => {
    headers.set(name, String(value).replace(/[\r\n\u0000]/g, " ").replace(/[^\u0000-\u00ff]/g, "?"));
  };
  return async function middleware(request, event) {
    const path = request.nextUrl.pathname;
    const startTime = Date.now();
    if (shouldSkipPath(path, skipPaths) || !shouldIncludePath(path, config.includePaths)) {
      return server.NextResponse.next();
    }
    try {
      const client2 = getClient();
      const userAgent = request.headers.get("user-agent") || void 0;
      // NOTE(review): x-forwarded-for / x-real-ip are client-controlled unless a trusted proxy
      // overwrites them; the leftmost x-forwarded-for entry can be spoofed.
      const ipAddress = request.ip || request.headers.get("x-forwarded-for")?.split(",")[0]?.trim() || request.headers.get("x-real-ip") || void 0;
      const result = await client2.enforce({
        // NOTE(review): this forwards every request header (including Cookie and
        // Authorization) to the AgentShield API; confirm that is intended.
        headers: Object.fromEntries(request.headers.entries()),
        userAgent,
        ipAddress,
        path,
        url: request.url,
        method: request.method,
        requestId: request.headers.get("x-request-id") || void 0,
        options: {
          // Always include detection results for logging (needed when using edge)
          includeDetectionResult: true
        }
      });
      if (!result.success || !result.data) {
        if (config.debug) {
          console.warn("[AgentShield] API error:", result.error);
        }
        if (failOpen) {
          return server.NextResponse.next();
        }
        return server.NextResponse.json(
          { error: "Security check failed", code: "API_ERROR" },
          { status: 503 }
        );
      }
      const decision = result.data.decision;
      if (config.debug) {
        console.log("[AgentShield] Decision:", {
          path,
          action: decision.action,
          isAgent: decision.isAgent,
          confidence: decision.confidence,
          agentName: decision.agentName,
          detectionMethod: result.data.detection?.detectionMethod || "not-included",
          processingTimeMs: Date.now() - startTime
        });
      }
      if (client2.isUsingEdge() && result.data.detection) {
        const logPromise = client2.logDetection({
          detection: result.data.detection,
          context: { userAgent, ipAddress, path, url: request.url, method: request.method }
        }).catch((err) => {
          if (config.debug) {
            console.error("[AgentShield] Log detection failed:", err);
          }
        });
        // Not awaited, so the response is not delayed. If the runtime supplies a
        // NextFetchEvent, register the write with waitUntil so it can finish after the
        // response has been returned.
        event?.waitUntil?.(logPromise);
      }
      if (decision.isAgent && config.onAgentDetected) {
        try {
          await config.onAgentDetected(request, decision);
        } catch (hookError) {
          // The hook only observes detections. If it fails, the API's decision must still
          // be enforced rather than falling through to the fail-open path below.
          if (config.debug) {
            console.error("[AgentShield] onAgentDetected hook failed:", hookError);
          }
        }
      }
      switch (decision.action) {
        case "block": {
          if (config.customBlockedResponse) {
            return config.customBlockedResponse(request, decision);
          }
          if (config.onBlock === "redirect") {
            return buildRedirectResponse(request, decision, config);
          }
          return buildBlockedResponse(decision, config);
        }
        case "redirect":
        case "challenge": {
          return buildRedirectResponse(request, decision, config);
        }
        case "log":
        case "allow":
        default: {
          const response = server.NextResponse.next();
          if (decision.isAgent) {
            setHeader(response.headers, "X-AgentShield-Detected", "true");
            if (decision.confidence != null) {
              setHeader(response.headers, "X-AgentShield-Confidence", decision.confidence);
            }
            if (decision.agentName) {
              setHeader(response.headers, "X-AgentShield-Agent", decision.agentName);
            }
          }
          return response;
        }
      }
    } catch (error) {
      if (config.debug) {
        console.error("[AgentShield] Middleware error:", error);
      }
      if (failOpen) {
        return server.NextResponse.next();
      }
      return server.NextResponse.json(
        { error: "Security check failed", code: "MIDDLEWARE_ERROR" },
        { status: 503 }
      );
    }
  };
}
// Ready-made middleware with default options. Settings come from the AGENTSHIELD_*
// environment variables, which are read lazily when the first checked request arrives.
var agentShieldMiddleware = withAgentShield();
Object.assign(exports, { agentShieldMiddleware, withAgentShield });
//# sourceMappingURL=api-middleware.js.map
//# sourceMappingURL=api-middleware.js.map