@convex-dev/rate-limiter
Version:
A rate limiter component for Convex. Define and use application-layer rate limits. Type-safe, transactional, fair, and safe, with configurable sharding to scale.
142 lines (130 loc) • 3.96 kB
text/typescript
import { v } from "convex/values";
import { mutation, query } from "./_generated/server.js";
import {
calculateRateLimit,
getValueReturns,
rateLimitArgs,
configValidator,
rateLimitReturns,
type GetValueReturns,
} from "../shared.js";
import {
checkRateLimitOrThrow,
configWithDefaults,
getShard,
} from "./internal.js";
import { api } from "./_generated/api.js";
/**
 * Checks the rate limit described by `args` and persists the resulting shard
 * state within the same mutation.
 *
 * `checkRateLimitOrThrow` supplies the `status` to return plus a list of
 * `updates`. Each update is written back either by patching the shard
 * document it already refers to, or by inserting a new `rateLimits` document
 * when no existing one is attached.
 *
 * @returns The `status` produced by `checkRateLimitOrThrow`.
 */
export const rateLimit = mutation({
  args: rateLimitArgs,
  returns: rateLimitReturns,
  handler: async (ctx, args) => {
    const outcome = await checkRateLimitOrThrow(ctx.db, args);
    for (const update of outcome.updates) {
      const { value, ts, existing, shard } = update;
      if (!existing) {
        // No stored document for this shard yet: create one.
        const { name, key } = args;
        await ctx.db.insert("rateLimits", { name, key, ts, value, shard });
        continue;
      }
      // The shard already has a document: only its timestamp and value change.
      await ctx.db.patch(existing._id, { ts, value });
    }
    return outcome.status;
  },
});
/**
 * Query counterpart of the rate limit check: runs `checkRateLimitOrThrow`
 * with the same arguments and returns only its `status`. Any `updates` it
 * computes are discarded here, so nothing is written.
 *
 * @returns The `status` produced by `checkRateLimitOrThrow`.
 */
export const checkRateLimit = query({
  args: rateLimitArgs,
  returns: rateLimitReturns,
  handler: async (ctx, args) => {
    const result = await checkRateLimitOrThrow(ctx.db, args);
    return result.status;
  },
});
/**
 * Reads the stored state of a rate limit by sampling some of its shards.
 *
 * Steps:
 * 1. Resolve the config through `configWithDefaults`.
 * 2. Pick `sampleShards` distinct shard indices at random (default 1, never
 *    fewer than 1 and never more than `config.shards`).
 * 3. Load each selected shard with `getShard`; a shard with no stored state
 *    is treated as `{ value: config.capacity, ts: 0 }`.
 * 4. Evaluate every sampled shard at the latest `ts` seen among them and
 *    choose the one with the largest computed value.
 *
 * @returns The stored `value`, `ts` and `shard` of the chosen shard together
 *   with the resolved `config`. For a "fixed window" config without a
 *   `start`, `config.start` is filled in from the chosen shard's computed
 *   `windowStart`.
 */
export const getValue = query({
  args: {
    name: v.string(),
    key: v.optional(v.string()),
    config: configValidator,
    sampleShards: v.optional(v.number()),
  },
  returns: getValueReturns,
  handler: async (ctx, args): Promise<GetValueReturns> => {
    const config = configWithDefaults(args.config);
    // Clamp to at least one sample. A negative `sampleShards` would otherwise
    // select no shards and make the `reduce` below throw on an empty array.
    // `||` (not `??`) is intentional so that 0 also falls back to 1.
    // NOTE(review): assumes `config.shards` is >= 1 after `configWithDefaults` — confirm.
    const samplesToTake = Math.max(
      1,
      Math.min(args.sampleShards || 1, config.shards),
    );
    // Draw shard indices without replacement: each pick is removed from the
    // candidate pool so the same shard is never sampled twice.
    const shardIndices = Array.from({ length: config.shards }, (_, i) => i);
    const selectedShards: number[] = [];
    for (let i = 0; i < samplesToTake; i++) {
      if (shardIndices.length === 0) break;
      const randomIndex = Math.floor(Math.random() * shardIndices.length);
      selectedShards.push(shardIndices[randomIndex]);
      shardIndices.splice(randomIndex, 1);
    }
    const allShards = (
      await Promise.all(
        selectedShards.map((shard) =>
          getShard(ctx.db, args.name, args.key, shard),
        ),
      )
    ).map(
      // A shard that has never been written is reported at full capacity.
      (state, i) =>
        state ?? { value: config.capacity, ts: 0, shard: selectedShards[i]! },
    );
    const maxTs = Math.max(...allShards.map((shard) => shard.ts));
    // We calculate the values as if each shard was at the latest ts.
    // We avoid passing Date.now() so the query isn't too time-aware.
    const values = allShards.map((state) => ({
      ...state,
      maxTs: calculateRateLimit(state, config, maxTs),
    }));
    // Keep the shard whose value at `maxTs` is largest; on a tie the later
    // element wins.
    const maxShard = values.reduce((a, b) =>
      a.maxTs.value > b.maxTs.value ? a : b,
    );
    if (config.kind === "fixed window" && !config.start) {
      // We can modify here b/c config is our copy.
      config.start = maxShard.maxTs.windowStart;
    }
    return {
      value: maxShard.value,
      ts: maxShard.ts,
      shard: maxShard.shard,
      config,
    };
  },
});
/**
 * Deletes every `rateLimits` document stored for the given `name` and
 * optional `key`, looked up through the "name" index.
 */
export const resetRateLimit = mutation({
  args: {
    name: v.string(),
    key: v.optional(v.string()),
  },
  returns: v.null(),
  handler: async (ctx, args) => {
    const { name, key } = args;
    const matching = await ctx.db
      .query("rateLimits")
      .withIndex("name", (q) => q.eq("name", name).eq("key", key))
      .collect();
    // Remove the documents one after another.
    for (const { _id } of matching) {
      await ctx.db.delete(_id);
    }
  },
});
/**
 * Deletes `rateLimits` documents in batches of 100, newest first.
 *
 * Only documents whose `_creationTime` is at or before `before` (default:
 * `Date.now()`) are considered. When a full batch was deleted, the mutation
 * schedules itself again, using the creation time of the last document in
 * the batch as the next `before` bound.
 */
export const clearAll = mutation({
  args: { before: v.optional(v.number()) },
  returns: v.null(),
  handler: async (ctx, args) => {
    const batchSize = 100;
    const batch = await ctx.db
      .query("rateLimits")
      .withIndex("by_creation_time", (q) =>
        q.lte("_creationTime", args.before ?? Date.now()),
      )
      .order("desc")
      .take(batchSize);
    for (const doc of batch) {
      await ctx.db.delete(doc._id);
    }
    // A short batch means there was nothing more to fetch: stop here.
    if (batch.length !== batchSize) {
      return;
    }
    // Continue from the last document of this batch in a follow-up run.
    await ctx.scheduler.runAfter(0, api.lib.clearAll, {
      before: batch[batchSize - 1]._creationTime,
    });
  },
});
// Re-export `getServerTime` from ./time.js alongside the functions defined in this file.
export { getServerTime } from "./time.js";