UNPKG

@tylercoles/mcp-rate-limit

Version:

Rate-limiting utilities for the MCP framework

503 lines 17.2 kB
/**
 * Memory-based rate limiter implementation.
 *
 * Fixed-window counter keyed by arbitrary strings. All state lives in this
 * instance's `windows` Map, so limits are per process and not shared.
 */
export class MemoryRateLimiter {
  /** @type {Map<string, {requests: number, resetTime: number, windowMs: number, limit: number}>} */
  windows = new Map();

  /** Handle of the periodic sweep timer, or null when no timer is running. */
  cleanupInterval = null;

  /**
   * @param {number} [cleanupIntervalMs=60000] - How often (ms) expired windows
   *   are swept from memory. Zero, negative or non-finite values disable the
   *   periodic sweep; expired windows are then only discarded lazily by
   *   `check` / `getStats`.
   */
  constructor(cleanupIntervalMs = 60000) {
    // Largest delay Node timers accept; larger (and < 1) delays become 1 ms.
    const maxTimerDelayMs = 2147483647;
    const intervalMs = Number(cleanupIntervalMs);
    // Reject values that setInterval would silently turn into a 1 ms busy loop.
    if (Number.isFinite(intervalMs) && intervalMs > 0) {
      this.cleanupInterval = setInterval(() => {
        this.cleanupExpiredWindows();
      }, Math.min(intervalMs, maxTimerDelayMs));
      // Housekeeping alone must not keep the process alive. `unref` exists on
      // Node timers; other runtimes return a plain id, hence the `?.`.
      this.cleanupInterval.unref?.();
    }
  }

  /**
   * Count one request against `key` if it is still within its limit.
   * Rejected requests are not counted.
   *
   * @param {string} key - Rate-limit bucket identifier.
   * @param {number} limit - Maximum requests allowed per window.
   * @param {number} windowMs - Window length in ms (applied when a new window starts).
   * @returns {Promise<{allowed: boolean, remaining: number, resetTime: number,
   *   totalRequests: number, retryAfter: (number|undefined)}>} `resetTime` is
   *   epoch ms; `retryAfter` (whole seconds) is only set when rejected.
   */
  async check(key, limit, windowMs) {
    const now = Date.now();
    let window = this.windows.get(key);

    // Start a fresh window if none exists or the previous one has expired.
    if (!window || now >= window.resetTime) {
      window = { requests: 0, resetTime: now + windowMs, windowMs, limit };
      this.windows.set(key, window);
    }

    // Limits may change between calls; new values apply immediately while the
    // current window keeps its original reset time.
    window.limit = limit;
    window.windowMs = windowMs;

    const allowed = window.requests < limit;
    if (allowed) {
      window.requests += 1;
    }

    return {
      allowed,
      remaining: Math.max(0, limit - window.requests),
      resetTime: window.resetTime,
      totalRequests: window.requests,
      // now < resetTime here, so a rejected request always gets >= 1 second.
      retryAfter: allowed ? undefined : Math.ceil((window.resetTime - now) / 1000),
    };
  }

  /**
   * Forget all state for `key`.
   * @param {string} key
   */
  async reset(key) {
    this.windows.delete(key);
  }

  /**
   * Report the live window for `key`.
   *
   * @param {string} key
   * @returns {Promise<object|null>} Window stats, or null when there is no
   *   window or it has already expired.
   */
  async getStats(key) {
    const window = this.windows.get(key);
    if (!window) {
      return null;
    }
    // An expired window no longer constrains anything (the next `check` would
    // replace it), so report it as absent instead of returning stale counts.
    if (Date.now() >= window.resetTime) {
      this.windows.delete(key);
      return null;
    }
    return {
      key,
      requests: window.requests,
      remaining: Math.max(0, window.limit - window.requests),
      resetTime: window.resetTime,
      windowMs: window.windowMs,
      limit: window.limit,
    };
  }

  /** Stop the sweep timer and drop every window. */
  async cleanup() {
    if (this.cleanupInterval) {
      clearInterval(this.cleanupInterval);
      this.cleanupInterval = null;
    }
    this.windows.clear();
  }

  /** Remove every window whose reset time has passed. */
  cleanupExpiredWindows() {
    const now = Date.now();
    // Deleting the current entry while iterating a Map is safe in JS.
    for (const [key, window] of this.windows) {
      if (now >= window.resetTime) {
        this.windows.delete(key);
      }
    }
  }

  /**
   * Get current number of active windows (for monitoring)
   */
  getActiveWindowCount() {
    return this.windows.size;
  }

  /**
   * Get all active keys (for monitoring)
   */
  getActiveKeys() {
    return Array.from(this.windows.keys());
  }
}
/**
 * Rate limiting utility functions
 */
export class RateLimitUtils {
  // Stateless helpers shared by the HTTP and WebSocket limiters below: key
  // builders, client IP extraction and Retry-After / message formatting.

  /**
   * Generate rate limit key from IP address.
   * @param {string} ip
   * @param {string} [prefix='ip']
   * @returns {string} `<prefix>:<ip>`
   */
  static ipKey(ip, prefix = 'ip') {
    return `${prefix}:${ip}`;
  }

  /**
   * Generate rate limit key from session ID.
   * @param {string} sessionId
   * @param {string} [prefix='session']
   * @returns {string} `<prefix>:<sessionId>`
   */
  static sessionKey(sessionId, prefix = 'session') {
    return `${prefix}:${sessionId}`;
  }

  /**
   * Generate rate limit key from OAuth client ID.
   * @param {string} clientId
   * @param {string} [prefix='oauth']
   * @returns {string} `<prefix>:<clientId>`
   */
  static oauthClientKey(clientId, prefix = 'oauth') {
    return `${prefix}:${clientId}`;
  }

  /**
   * Generate rate limit key from user ID.
   * @param {string} userId
   * @param {string} [prefix='user']
   * @returns {string} `<prefix>:<userId>`
   */
  static userKey(userId, prefix = 'user') {
    return `${prefix}:${userId}`;
  }

  /**
   * Generate composite key from multiple components, joined with ':'.
   * @param {...*} components
   * @returns {string}
   */
  static compositeKey(...components) {
    return components.join(':');
  }

  /**
   * Extract the client IP address from a request.
   *
   * Lookup order: `req.ip`, the connection/socket `remoteAddress`, the first
   * entry of the `X-Forwarded-For` header, then '127.0.0.1'. Note that
   * `X-Forwarded-For` is client-controlled unless a trusted proxy rewrites it.
   *
   * @param {object} req - Node/Express-style request.
   * @returns {string}
   */
  static getClientIp(req) {
    return (
      req.ip ||
      req.connection?.remoteAddress ||
      req.socket?.remoteAddress ||
      req.headers?.['x-forwarded-for']?.split(',')[0]?.trim() ||
      '127.0.0.1'
    );
  }

  /**
   * Calculate the Retry-After header value in seconds.
   * @param {number} resetTime - Window reset time in epoch milliseconds.
   * @returns {number} Whole seconds until `resetTime`; 0 if it already passed.
   */
  static calculateRetryAfter(resetTime) {
    return Math.max(0, Math.ceil((resetTime - Date.now()) / 1000));
  }

  /**
   * Format the rate limit error message as one sentence.
   * @param {{retryAfter?: number}} result
   * @returns {string}
   */
  static formatErrorMessage(result) {
    const retryAfter = result.retryAfter || 0;
    return `Rate limit exceeded. Try again in ${retryAfter} seconds.`;
  }
}

/**
 * HTTP rate limit middleware factory.
 *
 * Produces Express-style `(req, res, next)` middleware backed by `config.store`
 * (a new `MemoryRateLimiter` when omitted). Config fields read here: `global`,
 * `perClient`, `skipOnError`, `onLimitReached`, `headers`.
 */
export class HttpRateLimitMiddleware {
  store;
  config;

  constructor(config) {
    this.config = config;
    this.store = config.store || new MemoryRateLimiter();
  }

  /**
   * Create Express middleware for global (per client IP) rate limiting.
   *
   * Uses `config.global.windowMs` / `maxRequests`; returns a pass-through
   * middleware when `config.global` is absent. JSON-RPC `initialize` requests
   * are never limited so sessions can always be created.
   *
   * TODO: `skipSuccessfulRequests` / `skipFailedRequests` are accepted but not
   * enforced. Un-counting a finished request needs a decrement operation, and
   * the store methods used in this file (check/reset/getStats/cleanup) expose none.
   */
  createGlobalMiddleware() {
    if (!this.config.global) {
      return (req, res, next) => next();
    }
    const { windowMs, maxRequests } = this.config.global;
    return this.#buildMiddleware(
      (req) => RateLimitUtils.ipKey(RateLimitUtils.getClientIp(req), 'global'),
      maxRequests,
      windowMs,
    );
  }

  /**
   * Create Express middleware for per-client rate limiting.
   *
   * Uses `config.perClient.windowMs` / `maxRequests`, and `keyGenerator(req)`
   * when given. Otherwise the key is the OAuth client id, then the
   * `mcp-session-id` header, then the client IP. Session ids come from a
   * request header, so a client that rotates them gets fresh buckets.
   */
  createClientMiddleware() {
    if (!this.config.perClient) {
      return (req, res, next) => next();
    }
    const { windowMs, maxRequests, keyGenerator } = this.config.perClient;
    const resolveKey = keyGenerator || ((req) => HttpRateLimitMiddleware.#defaultClientKey(req));
    return this.#buildMiddleware(resolveKey, maxRequests, windowMs);
  }

  /** Default per-client key: OAuth client id, then session id, then IP. */
  static #defaultClientKey(req) {
    if (req.user?.clientId) {
      return RateLimitUtils.oauthClientKey(req.user.clientId);
    }
    const sessionId = req.headers['mcp-session-id'];
    if (sessionId) {
      return RateLimitUtils.sessionKey(sessionId);
    }
    return RateLimitUtils.ipKey(RateLimitUtils.getClientIp(req));
  }

  /** Shared middleware flow: skip `initialize`, consume a slot, route errors. */
  #buildMiddleware(resolveKey, maxRequests, windowMs) {
    return async (req, res, next) => {
      // Skip rate limiting for initialize requests to allow session creation.
      if (req.body?.method === 'initialize') {
        next();
        return;
      }
      let proceed;
      try {
        proceed = await this.#consume(req, res, resolveKey(req), maxRequests, windowMs);
      } catch (error) {
        if (this.config.skipOnError) {
          next();
        } else {
          next(error);
        }
        return;
      }
      // next() runs outside the try so a throw from downstream handlers can
      // never cause a second next() call from the catch above.
      if (proceed) {
        next();
      }
    };
  }

  /** Returns true when the request may continue, false after a 429 was handled. */
  async #consume(req, res, key, maxRequests, windowMs) {
    const result = await this.store.check(key, maxRequests, windowMs);
    if (!result.allowed) {
      this.handleRateLimit(req, res, result, maxRequests);
      return false;
    }
    this.addHeaders(res, result, maxRequests);
    return true;
  }

  /**
   * Handle rate limit exceeded.
   *
   * Delegates to `config.onLimitReached(req, res, result)` when configured.
   * Otherwise it responds with HTTP 429, the rate-limit headers, `Retry-After`,
   * and a JSON-RPC 2.0 error body.
   *
   * @param {object} req
   * @param {object} res
   * @param {object} result - Rejected store `check` result.
   * @param {number} [limit] - Configured maximum, reported in the limit header
   *   and `error.data.limit`. Defaults to `result.totalRequests` (or 0).
   */
  handleRateLimit(req, res, result, limit = result.totalRequests ?? 0) {
    if (this.config.onLimitReached) {
      this.config.onLimitReached(req, res, result);
      return;
    }
    // Default rate limit response
    const retryAfter = result.retryAfter || RateLimitUtils.calculateRetryAfter(result.resetTime);
    this.addHeaders(res, result, limit);
    res.header('Retry-After', retryAfter.toString());
    res.status(429).json({
      jsonrpc: '2.0',
      error: {
        code: -32009, // Custom MCP rate limit error code
        // Use the same retryAfter as the header so message and header agree.
        message: RateLimitUtils.formatErrorMessage({ ...result, retryAfter }),
        data: {
          retryAfter,
          limit,
          remaining: result.remaining,
          resetTime: result.resetTime,
        },
      },
      // `??` keeps a legitimate request id of 0 instead of turning it into null.
      id: req.body?.id ?? null,
    });
  }

  /**
   * Add rate limit headers to response (only if `config.headers.includeHeaders`).
   * Header names may be overridden through `config.headers.headerNames`.
   */
  addHeaders(res, result, limit) {
    if (!this.config.headers?.includeHeaders) {
      return;
    }
    const names = this.config.headers.headerNames || {};
    res.header(names.remaining || 'X-RateLimit-Remaining', result.remaining.toString());
    // Reset is sent as epoch seconds.
    res.header(names.reset || 'X-RateLimit-Reset', Math.ceil(result.resetTime / 1000).toString());
    res.header(names.limit || 'X-RateLimit-Limit', limit.toString());
    if (result.retryAfter) {
      res.header(names.retryAfter || 'Retry-After', result.retryAfter.toString());
    }
  }

  /**
   * Get rate limiter store
   */
  getStore() {
    return this.store;
  }

  /**
   * Cleanup resources
   */
  async cleanup() {
    await this.store.cleanup();
  }
}

/**
 * WebSocket rate limit manager.
 *
 * Tracks active connections per IP and applies connection limits
 * (`config.connectionLimits`) and message limits (`config.messageLimits`)
 * using `config.store` (a new `MemoryRateLimiter` when omitted).
 */
export class WebSocketRateLimitManager {
  store;
  config;
  connectionCounts = new Map();
  activeConnections = new Map();
  // Stable fallback ids for connection objects that expose no id of their own.
  #generatedIds = new WeakMap();
  #nextGeneratedId = 0;

  constructor(config) {
    this.config = config;
    this.store = config.store || new MemoryRateLimiter();
  }

  /**
   * Check if a new connection is allowed; if so, start tracking it.
   *
   * `perIp` is enforced through the store as a count of connections per window
   * (default 1 hour). `global` caps concurrently tracked connections; with
   * `actions.onConnectionLimitReached === 'close_oldest'` the oldest tracked
   * connection is closed instead of rejecting the new one.
   *
   * @param {string} ip
   * @param {object} connection
   * @returns {Promise<boolean>}
   */
  async checkConnectionLimit(ip, connection) {
    const limits = this.config.connectionLimits;
    if (!limits) {
      return true;
    }
    const closeOldest = this.config.actions?.onConnectionLimitReached === 'close_oldest';
    const isGlobalFull = () =>
      Boolean(limits.global) && this.#totalActiveConnections() >= limits.global.maxConnections;

    // Reject up front when the global cap is reached, so the rejected attempt
    // does not consume part of the IP's per-window budget.
    if (isGlobalFull() && !closeOldest) {
      return false;
    }

    // Check per-IP connection limit
    if (limits.perIp) {
      const { maxConnections, windowMs = 3600000 } = limits.perIp;
      const key = RateLimitUtils.ipKey(ip, 'ws-conn');
      const result = await this.store.check(key, maxConnections, windowMs);
      if (!result.allowed) {
        return false;
      }
    }

    // Re-check after the await: other connections may have been tracked meanwhile.
    if (isGlobalFull()) {
      if (!closeOldest) {
        return false;
      }
      this.closeOldestConnection();
    }

    this.trackConnection(ip, connection);
    return true;
  }

  /**
   * Check if a message is allowed under the per-connection and per-method limits.
   *
   * @param {object} connection
   * @param {{method?: string}} message
   * @returns {Promise<boolean>}
   */
  async checkMessageLimit(connection, message) {
    const limits = this.config.messageLimits;
    if (!limits) {
      return true;
    }
    const connectionId = this.getConnectionId(connection);

    if (limits.perConnection) {
      const { maxMessages, windowMs } = limits.perConnection;
      const key = RateLimitUtils.compositeKey('ws-msg', connectionId);
      const result = await this.store.check(key, maxMessages, windowMs);
      if (!result.allowed) {
        return this.handleMessageRateLimit(connection, result);
      }
    }

    // Only own keys count: `method` comes from the client, and names such as
    // 'toString' or 'constructor' must not resolve to inherited properties.
    const perMethod = limits.perMethod;
    if (perMethod && message.method && Object.hasOwn(perMethod, message.method)) {
      const methodLimits = perMethod[message.method];
      if (methodLimits) {
        const { maxMessages, windowMs } = methodLimits;
        const key = RateLimitUtils.compositeKey('ws-method', connectionId, message.method);
        const result = await this.store.check(key, maxMessages, windowMs);
        if (!result.allowed) {
          return this.handleMessageRateLimit(connection, result);
        }
      }
    }
    return true;
  }

  /**
   * Handle message rate limit exceeded. Always rejects the message; with the
   * 'close_connection' action the connection is also closed (code 1013).
   * @returns {boolean} false
   */
  handleMessageRateLimit(connection, result) {
    const action = this.config.actions?.onMessageLimitReached || 'drop';
    if (action === 'close_connection' && typeof connection.close === 'function') {
      connection.close(1013, 'Message rate limit exceeded');
    }
    // 'drop', 'throttle' (no delay mechanism yet, so it also drops) and
    // unknown actions all reject the message.
    return false;
  }

  /**
   * Track new connection and untrack it automatically on 'close' / 'error'
   * when the connection supports `on`.
   */
  trackConnection(ip, connection) {
    let connections = this.activeConnections.get(ip);
    if (!connections) {
      connections = [];
      this.activeConnections.set(ip, connections);
    }
    connections.push(connection);
    this.connectionCounts.set(ip, (this.connectionCounts.get(ip) || 0) + 1);

    // Untracking is idempotent, so receiving both events is harmless.
    const cleanup = () => this.untrackConnection(ip, connection);
    if (typeof connection.on === 'function') {
      connection.on('close', cleanup);
      connection.on('error', cleanup);
    }
  }

  /**
   * Untrack closed connection (no-op if it is not tracked under `ip`).
   */
  untrackConnection(ip, connection) {
    const connections = this.activeConnections.get(ip);
    if (!connections) {
      return;
    }
    const index = connections.indexOf(connection);
    if (index < 0) {
      return;
    }
    connections.splice(index, 1);
    if (connections.length === 0) {
      this.activeConnections.delete(ip);
      this.connectionCounts.delete(ip);
    } else {
      this.connectionCounts.set(ip, connections.length);
    }
  }

  /**
   * Close the oldest tracked connection (by `createdAt`; a missing timestamp
   * counts as oldest) to make room under the global limit.
   */
  closeOldestConnection() {
    let oldest = null;
    let oldestIp = null;
    // Start at Infinity so a connection is always selected when any exist,
    // even if its createdAt equals (or exceeds) the current time.
    let oldestTime = Infinity;
    for (const [ip, connections] of this.activeConnections) {
      for (const connection of connections) {
        const createdAt = connection.createdAt || 0;
        if (createdAt < oldestTime) {
          oldestTime = createdAt;
          oldest = connection;
          oldestIp = ip;
        }
      }
    }
    if (oldest && typeof oldest.close === 'function') {
      oldest.close(1013, 'Connection limit exceeded - closing oldest');
      // Free the slot now instead of waiting for an asynchronous 'close' event
      // (or forever, if the connection has no event emitter).
      this.untrackConnection(oldestIp, oldest);
    }
  }

  /**
   * Get a connection identifier: `id`, `_id` or `remoteAddress` when present,
   * otherwise an id generated once per connection object and reused.
   */
  getConnectionId(connection) {
    const ownId = connection.id || connection._id || connection.remoteAddress;
    if (ownId) {
      return ownId;
    }
    // WeakMap keys must be objects; primitives are their own stable identity.
    if ((typeof connection !== 'object' && typeof connection !== 'function') || connection === null) {
      return String(connection);
    }
    // A fresh random value per call would give every message its own key, so
    // per-connection limits would never trigger; reuse one id per object.
    let generated = this.#generatedIds.get(connection);
    if (generated === undefined) {
      this.#nextGeneratedId += 1;
      generated = `ws-anon-${this.#nextGeneratedId}`;
      this.#generatedIds.set(connection, generated);
    }
    return generated;
  }

  /** Sum of tracked connections across all IPs. */
  #totalActiveConnections() {
    let total = 0;
    for (const connections of this.activeConnections.values()) {
      total += connections.length;
    }
    return total;
  }

  /**
   * Get connection statistics
   * @returns {{totalConnections: number, connectionsPerIp: Object<string, number>}}
   */
  getConnectionStats() {
    let totalConnections = 0;
    const connectionsPerIp = {};
    for (const [ip, count] of this.connectionCounts) {
      totalConnections += count;
      connectionsPerIp[ip] = count;
    }
    return { totalConnections, connectionsPerIp };
  }

  /**
   * Cleanup resources
   */
  async cleanup() {
    await this.store.cleanup();
    this.connectionCounts.clear();
    this.activeConnections.clear();
  }
}
// All exports are defined above in this file
//# sourceMappingURL=index.js.map