@tylercoles/mcp-rate-limit
Version:
Rate limiting utilities for MCP framework
503 lines • 17.2 kB
JavaScript
/**
 * Memory-based fixed-window rate limiter.
 *
 * Each key gets a window that opens on its first request and closes `windowMs` milliseconds
 * later; at most `limit` requests are admitted per window. State lives in a Map owned by this
 * instance, so counts are not shared with other instances or processes.
 */
export class MemoryRateLimiter {
  /** Per-key window state: `{ requests, resetTime, windowMs, limit }`. */
  windows = new Map();
  /** Handle of the periodic sweep timer, or `null` once `cleanup()` has run. */
  cleanupInterval = null;

  /**
   * @param {number} [cleanupIntervalMs=60000] - How often (ms) expired windows are purged.
   */
  constructor(cleanupIntervalMs = 60000) {
    // Purge expired windows periodically so the Map does not grow without bound.
    this.cleanupInterval = setInterval(() => {
      this.cleanupExpiredWindows();
    }, cleanupIntervalMs);
    // The sweep is pure housekeeping and must not, on its own, keep a Node.js process alive.
    // `unref` only exists on Node timer objects, hence the optional call.
    this.cleanupInterval?.unref?.();
  }

  /**
   * Record one request for `key` and report whether it fits in the current window.
   *
   * @param {string} key - Rate-limit bucket identifier.
   * @param {number} limit - Maximum number of requests admitted per window.
   * @param {number} windowMs - Window length in milliseconds.
   * @returns {Promise<{allowed: boolean, remaining: number, resetTime: number,
   *   totalRequests: number, retryAfter: (number|undefined)}>} `resetTime` is an epoch
   *   timestamp in ms; `retryAfter` (whole seconds) is only set when the request is denied.
   */
  async check(key, limit, windowMs) {
    const now = Date.now();
    let window = this.windows.get(key);
    // Open a fresh window if none exists or the previous one has expired.
    if (!window || now >= window.resetTime) {
      window = {
        requests: 0,
        resetTime: now + windowMs,
        windowMs,
        limit,
      };
      this.windows.set(key, window);
    }
    // Adopt changed parameters for an existing window; its resetTime is left unchanged.
    if (window.limit !== limit || window.windowMs !== windowMs) {
      window.limit = limit;
      window.windowMs = windowMs;
    }
    const allowed = window.requests < limit;
    // Denied requests are not counted against the window.
    if (allowed) {
      window.requests++;
    }
    const remaining = Math.max(0, limit - window.requests);
    const retryAfter = allowed ? undefined : Math.ceil((window.resetTime - now) / 1000);
    return {
      allowed,
      remaining,
      resetTime: window.resetTime,
      totalRequests: window.requests,
      retryAfter,
    };
  }

  /**
   * Forget all state for `key`; its next request opens a new window.
   * @param {string} key
   */
  async reset(key) {
    this.windows.delete(key);
  }

  /**
   * Snapshot of the live window for `key`.
   *
   * @param {string} key
   * @returns {Promise<object|null>} `null` when there is no window or the window has expired.
   */
  async getStats(key) {
    const window = this.windows.get(key);
    // An expired window may linger until the next sweep, but it no longer applies: report it as
    // absent instead of returning stale counts.
    if (!window || Date.now() >= window.resetTime) {
      return null;
    }
    return {
      key,
      requests: window.requests,
      remaining: Math.max(0, window.limit - window.requests),
      resetTime: window.resetTime,
      windowMs: window.windowMs,
      limit: window.limit,
    };
  }

  /**
   * Stop the periodic sweep and drop all windows.
   */
  async cleanup() {
    if (this.cleanupInterval) {
      clearInterval(this.cleanupInterval);
      this.cleanupInterval = null;
    }
    this.windows.clear();
  }

  /**
   * Remove every window whose reset time has passed.
   */
  cleanupExpiredWindows() {
    const now = Date.now();
    // Deleting the current entry while iterating a Map is well-defined in JavaScript.
    for (const [key, window] of this.windows.entries()) {
      if (now >= window.resetTime) {
        this.windows.delete(key);
      }
    }
  }

  /**
   * Get current number of stored windows (for monitoring); may include expired, unswept ones.
   * @returns {number}
   */
  getActiveWindowCount() {
    return this.windows.size;
  }

  /**
   * Get all stored keys (for monitoring); may include expired, unswept ones.
   * @returns {string[]}
   */
  getActiveKeys() {
    return Array.from(this.windows.keys());
  }
}
/**
 * Rate limiting utility functions: key builders, client-IP extraction and response formatting.
 */
export class RateLimitUtils {
  /**
   * Generate rate limit key from IP address.
   * @param {string} ip
   * @param {string} [prefix='ip']
   * @returns {string} `"<prefix>:<ip>"`
   */
  static ipKey(ip, prefix = 'ip') {
    return `${prefix}:${ip}`;
  }

  /**
   * Generate rate limit key from session ID.
   * @param {string} sessionId
   * @param {string} [prefix='session']
   * @returns {string} `"<prefix>:<sessionId>"`
   */
  static sessionKey(sessionId, prefix = 'session') {
    return `${prefix}:${sessionId}`;
  }

  /**
   * Generate rate limit key from OAuth client ID.
   * @param {string} clientId
   * @param {string} [prefix='oauth']
   * @returns {string} `"<prefix>:<clientId>"`
   */
  static oauthClientKey(clientId, prefix = 'oauth') {
    return `${prefix}:${clientId}`;
  }

  /**
   * Generate rate limit key from user ID.
   * @param {string} userId
   * @param {string} [prefix='user']
   * @returns {string} `"<prefix>:<userId>"`
   */
  static userKey(userId, prefix = 'user') {
    return `${prefix}:${userId}`;
  }

  /**
   * Generate composite key by joining all components with `:`.
   * @param {...*} components
   * @returns {string}
   */
  static compositeKey(...components) {
    return components.join(':');
  }

  /**
   * Best-effort client IP extraction.
   *
   * Precedence: `req.ip` (presumably populated by the framework, e.g. Express), then the raw
   * connection/socket address, and only then the first entry of `X-Forwarded-For`. Because a
   * socket address is normally present, the header is effectively a last resort; this is the
   * safer order, since the header is client-supplied and can be spoofed.
   *
   * @param {object} req - Incoming HTTP request.
   * @returns {string} The client address, or `'127.0.0.1'` when nothing is available.
   */
  static getClientIp(req) {
    return req.ip ||
      req.connection?.remoteAddress ||
      req.socket?.remoteAddress ||
      req.headers?.['x-forwarded-for']?.split(',')[0]?.trim() ||
      '127.0.0.1';
  }

  /**
   * Calculate the Retry-After header value in whole seconds.
   *
   * @param {number} resetTime - Epoch timestamp (ms) at which the window resets.
   * @returns {number} Seconds until `resetTime`, rounded up; never negative.
   */
  static calculateRetryAfter(resetTime) {
    // A reset time already in the past means the client may retry immediately; a negative
    // Retry-After value is not valid.
    return Math.max(0, Math.ceil((resetTime - Date.now()) / 1000));
  }

  /**
   * Format the human-readable rate limit error message.
   * @param {{retryAfter?: number}} result
   * @returns {string}
   */
  static formatErrorMessage(result) {
    const retryAfter = result.retryAfter || 0;
    return `Rate limit exceeded. Try again in ${retryAfter} seconds.`;
  }
}
/**
 * HTTP rate limit middleware factory.
 *
 * Produces Express-style `(req, res, next)` middlewares: one keyed by client IP
 * (`config.global`) and one keyed by OAuth client, MCP session or IP (`config.perClient`).
 * Both let `initialize` requests through so a client can always establish a session.
 */
export class HttpRateLimitMiddleware {
  store;
  config;

  /**
   * @param {object} config - Middleware configuration. `config.store`, when given, must expose
   *   `check(key, limit, windowMs)` and `cleanup()`; otherwise a MemoryRateLimiter is created.
   */
  constructor(config) {
    this.config = config;
    this.store = config.store || new MemoryRateLimiter();
  }

  /**
   * Create Express middleware for global (per-IP) rate limiting.
   * @returns {Function} A pass-through middleware when `config.global` is not set.
   */
  createGlobalMiddleware() {
    if (!this.config.global) {
      return (req, res, next) => next();
    }
    const { windowMs, maxRequests, skipSuccessfulRequests, skipFailedRequests } = this.config.global;
    return async (req, res, next) => {
      try {
        // Skip rate limiting for initialize requests to allow session creation
        if (req.body?.method === 'initialize') {
          next();
          return;
        }
        const key = RateLimitUtils.ipKey(RateLimitUtils.getClientIp(req), 'global');
        const result = await this.store.check(key, maxRequests, windowMs);
        if (!result.allowed) {
          return this.handleRateLimit(req, res, result, maxRequests);
        }
        // Track response status for skip logic
        if (skipSuccessfulRequests || skipFailedRequests) {
          const originalSend = res.send;
          res.send = function (data) {
            const statusCode = res.statusCode;
            const shouldSkip = (skipSuccessfulRequests && statusCode < 400) ||
              (skipFailedRequests && statusCode >= 400);
            if (shouldSkip) {
              // TODO: not implemented. The request has already been counted and the store API
              // offers no way to un-count it, so skipSuccessfulRequests/skipFailedRequests
              // currently have no effect.
            }
            return originalSend.call(this, data);
          };
        }
        this.addHeaders(res, result, maxRequests);
        next();
      } catch (error) {
        if (this.config.skipOnError) {
          next();
        } else {
          next(error);
        }
      }
    };
  }

  /**
   * Create Express middleware for per-client rate limiting.
   *
   * Unless `keyGenerator` is configured, the key is the OAuth client id (`req.user.clientId`),
   * else the `mcp-session-id` header, else the client IP.
   *
   * @returns {Function} A pass-through middleware when `config.perClient` is not set.
   */
  createClientMiddleware() {
    if (!this.config.perClient) {
      return (req, res, next) => next();
    }
    const { windowMs, maxRequests, keyGenerator } = this.config.perClient;
    return async (req, res, next) => {
      try {
        // Skip rate limiting for initialize requests to allow session creation
        if (req.body?.method === 'initialize') {
          next();
          return;
        }
        let key;
        if (keyGenerator) {
          key = keyGenerator(req);
        } else if (req.user?.clientId) {
          // OAuth client rate limiting
          key = RateLimitUtils.oauthClientKey(req.user.clientId);
        } else if (req.headers['mcp-session-id']) {
          // Session-based rate limiting
          key = RateLimitUtils.sessionKey(req.headers['mcp-session-id']);
        } else {
          // IP-based rate limiting
          key = RateLimitUtils.ipKey(RateLimitUtils.getClientIp(req));
        }
        const result = await this.store.check(key, maxRequests, windowMs);
        if (!result.allowed) {
          return this.handleRateLimit(req, res, result, maxRequests);
        }
        this.addHeaders(res, result, maxRequests);
        next();
      } catch (error) {
        if (this.config.skipOnError) {
          next();
        } else {
          next(error);
        }
      }
    };
  }

  /**
   * Handle a rejected request: delegate to `config.onLimitReached` or send a JSON-RPC 429.
   *
   * @param {object} req
   * @param {object} res
   * @param {object} result - Outcome returned by `store.check`.
   * @param {number} [limit] - Configured maximum for the window. When omitted (callers written
   *   against the older 3-argument form), `result.totalRequests` is used as a fallback.
   */
  handleRateLimit(req, res, result, limit) {
    if (this.config.onLimitReached) {
      this.config.onLimitReached(req, res, result);
      return;
    }
    const effectiveLimit = limit ?? result.totalRequests ?? 0;
    // Default rate limit response
    const retryAfter = result.retryAfter || RateLimitUtils.calculateRetryAfter(result.resetTime);
    this.addHeaders(res, result, effectiveLimit);
    res.header('Retry-After', retryAfter.toString());
    res.status(429).json({
      jsonrpc: '2.0',
      error: {
        code: -32009, // Custom MCP rate limit error code
        message: RateLimitUtils.formatErrorMessage(result),
        data: {
          retryAfter,
          limit: effectiveLimit,
          remaining: result.remaining,
          resetTime: result.resetTime,
        },
      },
      // A JSON-RPC id may legitimately be 0 or ''; only a missing id becomes null.
      id: req.body?.id ?? null,
    });
  }

  /**
   * Add rate limit headers to the response when `config.headers.includeHeaders` is set.
   *
   * @param {object} res
   * @param {object} result - Outcome returned by `store.check`.
   * @param {number} limit - Value for the limit header.
   */
  addHeaders(res, result, limit) {
    if (!this.config.headers?.includeHeaders) {
      return;
    }
    const headerNames = this.config.headers.headerNames || {
      remaining: 'X-RateLimit-Remaining',
      reset: 'X-RateLimit-Reset',
      limit: 'X-RateLimit-Limit',
      retryAfter: 'Retry-After',
    };
    res.header(headerNames.remaining || 'X-RateLimit-Remaining', result.remaining.toString());
    // Reset header is expressed in epoch seconds.
    res.header(headerNames.reset || 'X-RateLimit-Reset', Math.ceil(result.resetTime / 1000).toString());
    res.header(headerNames.limit || 'X-RateLimit-Limit', limit.toString());
    if (result.retryAfter) {
      res.header(headerNames.retryAfter || 'Retry-After', result.retryAfter.toString());
    }
  }

  /**
   * Get rate limiter store.
   * @returns {object}
   */
  getStore() {
    return this.store;
  }

  /**
   * Cleanup resources held by the store.
   */
  async cleanup() {
    await this.store.cleanup();
  }
}
/**
 * WebSocket rate limit manager.
 *
 * Admission of new connections is limited per IP (through the store, i.e. connection attempts
 * counted within a time window) and globally (by counting currently tracked connections).
 * Incoming messages are limited per connection and, optionally, per message `method`.
 */
export class WebSocketRateLimitManager {
  store;
  config;
  connectionCounts = new Map();
  activeConnections = new Map();
  // Stable generated ids for connection objects that carry no id/_id/remoteAddress.
  connectionIds = new WeakMap();
  nextConnectionId = 0;

  /**
   * @param {object} config - Manager configuration. `config.store`, when given, must expose
   *   `check(key, limit, windowMs)` and `cleanup()`; otherwise a MemoryRateLimiter is created.
   */
  constructor(config) {
    this.config = config;
    this.store = config.store || new MemoryRateLimiter();
  }

  /**
   * Check whether a new connection is allowed and, if so, start tracking it.
   *
   * NOTE: the per-IP check runs first and consumes budget even when the global check then
   * rejects the connection.
   *
   * @param {string} ip - Client address of the connection.
   * @param {object} connection - The socket object.
   * @returns {Promise<boolean>} `true` if the connection is admitted.
   */
  async checkConnectionLimit(ip, connection) {
    if (!this.config.connectionLimits) {
      return true;
    }
    // Check per-IP connection limit
    if (this.config.connectionLimits.perIp) {
      const { maxConnections, windowMs = 3600000 } = this.config.connectionLimits.perIp;
      const key = RateLimitUtils.ipKey(ip, 'ws-conn');
      const result = await this.store.check(key, maxConnections, windowMs);
      if (!result.allowed) {
        return false;
      }
    }
    // Check global connection limit. No await between this check and trackConnection(), so
    // concurrent admissions cannot both pass on the same count.
    if (this.config.connectionLimits.global) {
      const { maxConnections } = this.config.connectionLimits.global;
      const totalConnections = Array.from(this.activeConnections.values())
        .reduce((total, connections) => total + connections.length, 0);
      if (totalConnections >= maxConnections) {
        if (this.config.actions?.onConnectionLimitReached === 'close_oldest') {
          this.closeOldestConnection();
        } else {
          return false;
        }
      }
    }
    // Track connection
    this.trackConnection(ip, connection);
    return true;
  }

  /**
   * Check whether an incoming message is allowed.
   *
   * @param {object} connection - The socket the message arrived on.
   * @param {object} message - Parsed message; its `method` selects a per-method limit.
   * @returns {Promise<boolean>} `true` if the message may be processed.
   */
  async checkMessageLimit(connection, message) {
    const limits = this.config.messageLimits;
    if (!limits) {
      return true;
    }
    const connectionId = this.getConnectionId(connection);
    // Check per-connection message limit
    if (limits.perConnection) {
      const { maxMessages, windowMs } = limits.perConnection;
      const key = RateLimitUtils.compositeKey('ws-msg', connectionId);
      const result = await this.store.check(key, maxMessages, windowMs);
      if (!result.allowed) {
        return this.handleMessageRateLimit(connection, result);
      }
    }
    // Check per-method message limit. `method` is client-supplied, so only honour limits the
    // config itself defines: names like "constructor" or "toString" must not resolve to
    // Object.prototype members.
    const method = message?.method;
    if (limits.perMethod && method && Object.hasOwn(limits.perMethod, method)) {
      const methodLimits = limits.perMethod[method];
      if (methodLimits) {
        const { maxMessages, windowMs } = methodLimits;
        const key = RateLimitUtils.compositeKey('ws-method', connectionId, method);
        const result = await this.store.check(key, maxMessages, windowMs);
        if (!result.allowed) {
          return this.handleMessageRateLimit(connection, result);
        }
      }
    }
    return true;
  }

  /**
   * Apply `config.actions.onMessageLimitReached` (default `'drop'`) to a rejected message.
   *
   * @param {object} connection
   * @param {object} result - Outcome returned by `store.check`.
   * @returns {boolean} Always `false`: the message is never processed.
   */
  handleMessageRateLimit(connection, result) {
    const action = this.config.actions?.onMessageLimitReached || 'drop';
    switch (action) {
      case 'drop':
        return false;
      case 'close_connection':
        if (typeof connection.close === 'function') {
          connection.close(1013, 'Message rate limit exceeded');
        }
        return false;
      case 'throttle':
        // Throttling (delaying) is not implemented; the message is dropped.
        return false;
      default:
        return false;
    }
  }

  /**
   * Start tracking a connection under `ip` and untrack it on `close`/`error`.
   * @param {string} ip
   * @param {object} connection
   */
  trackConnection(ip, connection) {
    if (!this.activeConnections.has(ip)) {
      this.activeConnections.set(ip, []);
    }
    this.activeConnections.get(ip).push(connection);
    // Update count
    const count = this.connectionCounts.get(ip) || 0;
    this.connectionCounts.set(ip, count + 1);
    // Both events may fire for one connection; untrackConnection is a no-op the second time.
    const cleanup = () => this.untrackConnection(ip, connection);
    if (typeof connection.on === 'function') {
      connection.on('close', cleanup);
      connection.on('error', cleanup);
    }
  }

  /**
   * Stop tracking a connection; removes the IP entry once it has no connections left.
   * @param {string} ip
   * @param {object} connection
   */
  untrackConnection(ip, connection) {
    const connections = this.activeConnections.get(ip);
    if (!connections) {
      return;
    }
    const index = connections.indexOf(connection);
    if (index < 0) {
      return;
    }
    connections.splice(index, 1);
    if (connections.length === 0) {
      this.activeConnections.delete(ip);
      this.connectionCounts.delete(ip);
    } else {
      this.connectionCounts.set(ip, connections.length);
    }
  }

  /**
   * Close the tracked connection with the smallest `createdAt` (missing counts as 0).
   * Untracking happens later, when the connection emits `close`.
   */
  closeOldestConnection() {
    let oldestConnection = null;
    // Start from Infinity so every tracked connection is a candidate, including ones whose
    // createdAt equals (or is later than) the current clock reading.
    let oldestTime = Infinity;
    for (const connections of this.activeConnections.values()) {
      for (const connection of connections) {
        // Assumes connections may carry a createdAt timestamp — TODO confirm with transport.
        const createdAt = connection.createdAt || 0;
        if (createdAt < oldestTime) {
          oldestTime = createdAt;
          oldestConnection = connection;
        }
      }
    }
    if (oldestConnection && typeof oldestConnection.close === 'function') {
      oldestConnection.close(1013, 'Connection limit exceeded - closing oldest');
    }
  }

  /**
   * Get a connection identifier that is stable across calls for the same connection.
   *
   * Uses `id`, `_id` or `remoteAddress` when present; otherwise generates an id once per
   * connection object and remembers it.
   *
   * @param {object} connection
   * @returns {*} The identifier used in rate-limit keys.
   */
  getConnectionId(connection) {
    const explicitId = connection.id || connection._id || connection.remoteAddress;
    if (explicitId) {
      return explicitId;
    }
    // A new random id per call would give every message its own bucket and silently disable
    // per-connection limits, so the generated id is remembered per connection object.
    if (typeof connection !== 'object' && typeof connection !== 'function') {
      return Math.random().toString(36);
    }
    let generatedId = this.connectionIds.get(connection);
    if (generatedId === undefined) {
      this.nextConnectionId += 1;
      generatedId = `anon-${this.nextConnectionId}`;
      this.connectionIds.set(connection, generatedId);
    }
    return generatedId;
  }

  /**
   * Get connection statistics for tracked connections.
   * @returns {{totalConnections: number, connectionsPerIp: Object<string, number>}}
   */
  getConnectionStats() {
    const totalConnections = Array.from(this.connectionCounts.values())
      .reduce((total, count) => total + count, 0);
    const connectionsPerIp = {};
    for (const [ip, count] of this.connectionCounts.entries()) {
      connectionsPerIp[ip] = count;
    }
    return { totalConnections, connectionsPerIp };
  }

  /**
   * Cleanup the store and forget all tracked connections.
   */
  async cleanup() {
    await this.store.cleanup();
    this.connectionCounts.clear();
    this.activeConnections.clear();
  }
}
// All exports are defined above in this file
//# sourceMappingURL=index.js.map