UNPKG

@fastify/rate-limit

Version:

A low overhead rate limiter for your routes

343 lines (286 loc) 11 kB
'use strict' const fp = require('fastify-plugin') const { parse, format } = require('@lukeed/ms') const LocalStore = require('./store/LocalStore') const RedisStore = require('./store/RedisStore') const defaultMax = 1000 const defaultTimeWindow = 60000 const defaultHook = 'onRequest' const defaultHeaders = { rateLimit: 'x-ratelimit-limit', rateRemaining: 'x-ratelimit-remaining', rateReset: 'x-ratelimit-reset', retryAfter: 'retry-after' } const draftSpecHeaders = { rateLimit: 'ratelimit-limit', rateRemaining: 'ratelimit-remaining', rateReset: 'ratelimit-reset', retryAfter: 'retry-after' } const defaultOnFn = () => {} const defaultKeyGenerator = (req) => req.ip const defaultErrorResponse = (_req, context) => { const err = new Error(`Rate limit exceeded, retry in ${context.after}`) err.statusCode = context.statusCode return err } async function fastifyRateLimit (fastify, settings) { const globalParams = { global: (typeof settings.global === 'boolean') ? settings.global : true } if (typeof settings.enableDraftSpec === 'boolean' && settings.enableDraftSpec) { globalParams.enableDraftSpec = true globalParams.labels = draftSpecHeaders } else { globalParams.enableDraftSpec = false globalParams.labels = defaultHeaders } globalParams.addHeaders = Object.assign({ [globalParams.labels.rateLimit]: true, [globalParams.labels.rateRemaining]: true, [globalParams.labels.rateReset]: true, [globalParams.labels.retryAfter]: true }, settings.addHeaders) globalParams.addHeadersOnExceeding = Object.assign({ [globalParams.labels.rateLimit]: true, [globalParams.labels.rateRemaining]: true, [globalParams.labels.rateReset]: true }, settings.addHeadersOnExceeding) // Global maximum allowed requests if (Number.isFinite(settings.max) && settings.max >= 0) { globalParams.max = Math.trunc(settings.max) } else if ( typeof settings.max === 'function' ) { globalParams.max = settings.max } else { globalParams.max = defaultMax } // Global time window if (Number.isFinite(settings.timeWindow) && settings.timeWindow >= 0) { globalParams.timeWindow = Math.trunc(settings.timeWindow) } else if (typeof settings.timeWindow === 'string') { globalParams.timeWindow = parse(settings.timeWindow) } else if ( typeof settings.timeWindow === 'function' ) { globalParams.timeWindow = settings.timeWindow } else { globalParams.timeWindow = defaultTimeWindow } globalParams.hook = settings.hook || defaultHook globalParams.allowList = settings.allowList || settings.whitelist || null globalParams.ban = Number.isFinite(settings.ban) && settings.ban >= 0 ? Math.trunc(settings.ban) : -1 globalParams.onBanReach = typeof settings.onBanReach === 'function' ? settings.onBanReach : defaultOnFn globalParams.onExceeding = typeof settings.onExceeding === 'function' ? settings.onExceeding : defaultOnFn globalParams.onExceeded = typeof settings.onExceeded === 'function' ? settings.onExceeded : defaultOnFn globalParams.continueExceeding = typeof settings.continueExceeding === 'boolean' ? settings.continueExceeding : false globalParams.exponentialBackoff = typeof settings.exponentialBackoff === 'boolean' ? settings.exponentialBackoff : false globalParams.keyGenerator = typeof settings.keyGenerator === 'function' ? settings.keyGenerator : defaultKeyGenerator if (typeof settings.errorResponseBuilder === 'function') { globalParams.errorResponseBuilder = settings.errorResponseBuilder globalParams.isCustomErrorMessage = true } else { globalParams.errorResponseBuilder = defaultErrorResponse globalParams.isCustomErrorMessage = false } globalParams.skipOnError = typeof settings.skipOnError === 'boolean' ? settings.skipOnError : false const pluginComponent = { rateLimitRan: Symbol('fastify.request.rateLimitRan'), store: null } if (settings.store) { const Store = settings.store pluginComponent.store = new Store(globalParams) } else { if (settings.redis) { pluginComponent.store = new RedisStore(globalParams.continueExceeding, globalParams.exponentialBackoff, settings.redis, settings.nameSpace) } else { pluginComponent.store = new LocalStore(globalParams.continueExceeding, globalParams.exponentialBackoff, settings.cache) } } fastify.decorateRequest(pluginComponent.rateLimitRan, false) if (!fastify.hasDecorator('createRateLimit')) { fastify.decorate('createRateLimit', (options) => { const args = createLimiterArgs(pluginComponent, globalParams, options) return (req) => applyRateLimit.apply(this, args.concat(req)) }) } if (!fastify.hasDecorator('rateLimit')) { fastify.decorate('rateLimit', (options) => { const args = createLimiterArgs(pluginComponent, globalParams, options) return rateLimitRequestHandler(...args) }) } fastify.addHook('onRoute', (routeOptions) => { if (routeOptions.config?.rateLimit != null) { if (typeof routeOptions.config.rateLimit === 'object') { const newPluginComponent = Object.create(pluginComponent) const mergedRateLimitParams = mergeParams(globalParams, routeOptions.config.rateLimit, { routeInfo: routeOptions }) newPluginComponent.store = pluginComponent.store.child(mergedRateLimitParams) addRouteRateHook(newPluginComponent, mergedRateLimitParams, routeOptions) } else if (routeOptions.config.rateLimit !== false) { throw new Error('Unknown value for route rate-limit configuration') } } else if (globalParams.global) { // As the endpoint does not have a custom configuration, use the global one addRouteRateHook(pluginComponent, globalParams, routeOptions) } }) } function mergeParams (...params) { const result = Object.assign({}, ...params) if (Number.isFinite(result.timeWindow) && result.timeWindow >= 0) { result.timeWindow = Math.trunc(result.timeWindow) } else if (typeof result.timeWindow === 'string') { result.timeWindow = parse(result.timeWindow) } else if (typeof result.timeWindow !== 'function') { result.timeWindow = defaultTimeWindow } if (Number.isFinite(result.max) && result.max >= 0) { result.max = Math.trunc(result.max) } else if (typeof result.max !== 'function') { result.max = defaultMax } if (Number.isFinite(result.ban) && result.ban >= 0) { result.ban = Math.trunc(result.ban) } else { result.ban = -1 } if (result.groupId !== undefined && typeof result.groupId !== 'string') { throw new Error('groupId must be a string') } return result } function createLimiterArgs (pluginComponent, globalParams, options) { if (typeof options === 'object') { const newPluginComponent = Object.create(pluginComponent) const mergedRateLimitParams = mergeParams(globalParams, options, { routeInfo: {} }) newPluginComponent.store = newPluginComponent.store.child(mergedRateLimitParams) return [newPluginComponent, mergedRateLimitParams] } return [pluginComponent, globalParams] } function addRouteRateHook (pluginComponent, params, routeOptions) { const hook = params.hook const hookHandler = rateLimitRequestHandler(pluginComponent, params) if (Array.isArray(routeOptions[hook])) { routeOptions[hook].push(hookHandler) } else if (typeof routeOptions[hook] === 'function') { routeOptions[hook] = [routeOptions[hook], hookHandler] } else { routeOptions[hook] = [hookHandler] } } async function applyRateLimit (pluginComponent, params, req) { const { store } = pluginComponent // Retrieve the key from the generator (the global one or the one defined in the endpoint) let key = await params.keyGenerator(req) const groupId = req.routeOptions.config?.rateLimit?.groupId if (groupId) { key += groupId } // Don't apply any rate limiting if in the allow list if (params.allowList) { if (typeof params.allowList === 'function') { if (await params.allowList(req, key)) { return { isAllowed: true, key } } } else if (params.allowList.indexOf(key) !== -1) { return { isAllowed: true, key } } } const max = typeof params.max === 'number' ? params.max : await params.max(req, key) const timeWindow = typeof params.timeWindow === 'number' ? params.timeWindow : await params.timeWindow(req, key) let current = 0 let ttl = 0 let ttlInSeconds = 0 // We increment the rate limit for the current request try { const res = await new Promise((resolve, reject) => { store.incr(key, (err, res) => { err ? reject(err) : resolve(res) }, timeWindow, max) }) current = res.current ttl = res.ttl ttlInSeconds = Math.ceil(res.ttl / 1000) } catch (err) { if (!params.skipOnError) { throw err } } return { isAllowed: false, key, max, timeWindow, remaining: Math.max(0, max - current), ttl, ttlInSeconds, isExceeded: current > max, isBanned: params.ban !== -1 && current - max > params.ban } } function rateLimitRequestHandler (pluginComponent, params) { const { rateLimitRan } = pluginComponent return async (req, res) => { if (req[rateLimitRan]) { return } req[rateLimitRan] = true const rateLimit = await applyRateLimit(pluginComponent, params, req) if (rateLimit.isAllowed) { return } const { key, max, remaining, ttl, ttlInSeconds, isExceeded, isBanned } = rateLimit if (!isExceeded) { if (params.addHeadersOnExceeding[params.labels.rateLimit]) { res.header(params.labels.rateLimit, max) } if (params.addHeadersOnExceeding[params.labels.rateRemaining]) { res.header(params.labels.rateRemaining, remaining) } if (params.addHeadersOnExceeding[params.labels.rateReset]) { res.header(params.labels.rateReset, ttlInSeconds) } params.onExceeding(req, key) return } params.onExceeded(req, key) if (params.addHeaders[params.labels.rateLimit]) { res.header(params.labels.rateLimit, max) } if (params.addHeaders[params.labels.rateRemaining]) { res.header(params.labels.rateRemaining, 0) } if (params.addHeaders[params.labels.rateReset]) { res.header(params.labels.rateReset, ttlInSeconds) } if (params.addHeaders[params.labels.retryAfter]) { res.header(params.labels.retryAfter, ttlInSeconds) } const respCtx = { statusCode: 429, ban: false, max, ttl, after: format(ttlInSeconds * 1000, true) } if (isBanned) { respCtx.statusCode = 403 respCtx.ban = true params.onBanReach(req, key) } throw params.errorResponseBuilder(req, respCtx) } } module.exports = fp(fastifyRateLimit, { fastify: '5.x', name: '@fastify/rate-limit' }) module.exports.default = fastifyRateLimit module.exports.fastifyRateLimit = fastifyRateLimit