UNPKG

fastify-session-redis-store

Version:

Redis session store for @fastify/session

235 lines (207 loc) 6 kB
import { MemoryStore } from "@fastify/session";

/**
 * Shape of a session object persisted by this store. The store itself only
 * reads `cookie.expires` (for TTL computation); `id` is filled in by `all()`
 * from the Redis key.
 */
export interface SessionData {
  id?: string;
  cookie: {
    originalMaxAge: number | null;
    maxAge?: number;
    signed?: boolean;
    expires?: Date | null;
    httpOnly?: boolean;
    path?: string;
    domain?: string;
    secure?: boolean | "auto";
    sameSite?: boolean | "lax" | "strict" | "none";
  };
}

/**
 * Minimal promise-based client surface used by the store. Produced by
 * `RedisStore.normalizeClient` so the store works with both node-redis and
 * ioredis clients.
 */
interface NormalizedRedisClient {
  get(key: string): Promise<string | null>;
  set(key: string, value: string, ttl?: number): Promise<string | null>;
  expire(key: string, ttl: number): Promise<number | boolean>;
  scanIterator(match: string, count: number): AsyncIterable<string>;
  del(key: string[]): Promise<number>;
  mget(key: string[]): Promise<(string | null)[]>;
}

/** Converts sessions to and from the string stored in Redis (defaults to `JSON`). */
export interface Serializer {
  parse(s: string): SessionData | Promise<SessionData>;
  stringify(s: SessionData): string;
}

/** Constructor options for `RedisStore`. */
export interface RedisStoreOptions {
  /** A node-redis or ioredis client instance. */
  client: any;
  /** Prefix prepended to every session id to form the Redis key. Defaults to `"sess:"`. */
  prefix?: string;
  /** `COUNT` hint passed to `SCAN` when enumerating keys. Defaults to `100`. */
  scanCount?: number;
  /** Session (de)serializer. Defaults to the global `JSON` object. */
  serializer?: Serializer;
  /**
   * Session lifetime in seconds, or a function computing it per session.
   * Defaults to `86400`. A function always wins; a number is used only when
   * the session cookie has no `expires`.
   */
  ttl?: number | ((sess: SessionData) => number);
  /**
   * When true, keys are written without an expiry and `touch` is a no-op
   * (sessions whose TTL is already non-positive are still destroyed on `set`).
   */
  disableTTL?: boolean;
  /** When true, `touch` is a no-op. */
  disableTouch?: boolean;
}

// Default callback so every method can also be used purely promise-style.
const noop = (_err?: unknown, _data?: any) => {};

/**
 * Redis-backed session store for `@fastify/session`.
 *
 * Each session is stored as a serialized string under `prefix + sessionId`.
 * Every method takes an optional Node-style callback and also returns a
 * promise that resolves to whatever the callback returns.
 */
export class RedisStore extends MemoryStore {
  client: NormalizedRedisClient;
  prefix: string;
  scanCount: number;
  serializer: Serializer;
  ttl: number | ((sess: SessionData) => number);
  disableTTL: boolean;
  disableTouch: boolean;

  constructor(opts: RedisStoreOptions) {
    super();
    this.prefix = opts.prefix ?? "sess:";
    this.scanCount = opts.scanCount ?? 100;
    this.serializer = opts.serializer ?? JSON;
    this.ttl = opts.ttl ?? 86400; // One day in seconds.
    this.disableTTL = opts.disableTTL ?? false;
    this.disableTouch = opts.disableTouch ?? false;
    this.client = this.normalizeClient(opts.client);
  }

  /**
   * Wraps a node-redis or ioredis client in a common interface. node-redis
   * clients are recognised by having a `scanIterator` method.
   */
  private normalizeClient(client: any): NormalizedRedisClient {
    const isRedis = "scanIterator" in client;
    return {
      get: (key) => client.get(key),
      set: (key, val, ttl) => {
        if (ttl) {
          // node-redis takes an options object; ioredis takes positional args.
          return isRedis
            ? client.set(key, val, { EX: ttl })
            : client.set(key, val, "EX", ttl);
        }
        return client.set(key, val);
      },
      del: (key) => client.del(key),
      expire: (key, ttl) => client.expire(key, ttl),
      mget: (keys) => (isRedis ? client.mGet(keys) : client.mget(keys)),
      scanIterator: (match, count) => {
        if (isRedis) {
          // node-redis v4 yields one key per step, v5 yields batches of keys;
          // flatten so callers always receive individual keys.
          return (async function* () {
            for await (const item of client.scanIterator({
              MATCH: match,
              COUNT: count,
            })) {
              if (Array.isArray(item)) yield* item;
              else yield item;
            }
          })();
        }
        // ioredis: drive SCAN manually until the cursor returns to "0".
        return (async function* () {
          let cursor = "0";
          do {
            const [next, batch] = await client.scan(
              cursor,
              "MATCH",
              match,
              "COUNT",
              count,
            );
            cursor = next;
            yield* batch;
          } while (cursor !== "0");
        })();
      },
    };
  }

  /** Loads and deserializes the session `sid`; calls back with no data if absent. */
  async get(sid: string, cb = noop) {
    const key = this.prefix + sid;
    try {
      const data = await this.client.get(key);
      if (!data) return cb();
      return cb(null, await this.serializer.parse(data));
    } catch (err) {
      return cb(err);
    }
  }

  /**
   * Writes `sess` under `prefix + sid` with the computed TTL. A non-positive
   * TTL means the session is already expired, so it is destroyed instead.
   * With `disableTTL` the key is written without an expiry.
   */
  async set(sid: string, sess: SessionData, cb = noop) {
    const key = this.prefix + sid;
    const ttl = this._getTTL(sess);
    try {
      const val = this.serializer.stringify(sess);
      if (ttl > 0) {
        if (this.disableTTL) await this.client.set(key, val);
        else await this.client.set(key, val, ttl);
        return cb();
      } else {
        return this.destroy(sid, cb);
      }
    } catch (err) {
      return cb(err);
    }
  }

  /**
   * Resets the key's expiry to the session's current TTL. No-op when
   * `disableTouch` or `disableTTL` is set.
   */
  async touch(sid: string, sess: SessionData, cb = noop) {
    const key = this.prefix + sid;
    if (this.disableTouch || this.disableTTL) return cb();
    try {
      await this.client.expire(key, this._getTTL(sess));
      return cb();
    } catch (err) {
      return cb(err);
    }
  }

  /** Deletes the session `sid`. */
  async destroy(sid: string, cb = noop) {
    const key = this.prefix + sid;
    try {
      await this.client.del([key]);
      return cb();
    } catch (err) {
      return cb(err);
    }
  }

  /** Deletes every key under `prefix` in a single DEL. */
  async clear(cb = noop) {
    try {
      const keys = await this._getAllKeys();
      if (!keys.length) return cb();
      await this.client.del(keys);
      return cb();
    } catch (err) {
      return cb(err);
    }
  }

  /** Calls back with the number of keys under `prefix`. */
  async length(cb = noop) {
    try {
      const keys = await this._getAllKeys();
      return cb(null, keys.length);
    } catch (err) {
      return cb(err);
    }
  }

  /** Calls back with every session id (keys under `prefix`, prefix stripped). */
  async ids(cb = noop) {
    const len = this.prefix.length;
    try {
      const keys = await this._getAllKeys();
      return cb(
        null,
        keys.map((k) => k.substring(len)),
      );
    } catch (err) {
      return cb(err);
    }
  }

  /**
   * Loads every session under `prefix`, setting each one's `id` from its key.
   * Keys whose value is missing by the time of MGET are skipped. Result order
   * follows SCAN order.
   */
  async all(cb = noop) {
    const len = this.prefix.length;
    try {
      const keys = await this._getAllKeys();
      if (keys.length === 0) return cb(null, []);
      const data = await this.client.mget(keys);
      // Promise.all observes every parse promise, so several corrupt entries
      // surface as one error instead of leaving unhandled rejections behind
      // (the previous async reduce never awaited earlier rejected accumulators).
      const parsed = await Promise.all(
        data.map(async (raw, idx) => {
          if (!raw) return null;
          const sess = await this.serializer.parse(raw);
          sess.id = keys[idx].substring(len);
          return sess;
        }),
      );
      const results = parsed.filter(
        (sess): sess is SessionData => sess !== null,
      );
      return cb(null, results);
    } catch (err) {
      return cb(err);
    }
  }

  /**
   * Seconds until `sess` should expire: the `ttl` option when it is a
   * function, otherwise the time remaining until `cookie.expires`, otherwise
   * the numeric `ttl` option. May be zero/negative for an expired cookie, or
   * NaN if `expires` is not a parsable date.
   */
  private _getTTL(sess: SessionData): number {
    if (typeof this.ttl === "function") {
      return this.ttl(sess);
    }
    if (sess?.cookie?.expires) {
      const ms = Number(new Date(sess.cookie.expires)) - Date.now();
      return Math.ceil(ms / 1000);
    }
    return this.ttl;
  }

  /**
   * Collects every key that starts with `prefix` via SCAN. Glob metacharacters
   * in the prefix are escaped so the MATCH pattern is a literal prefix match;
   * otherwise a prefix such as "app?:" would also match (and `clear` would
   * delete) unrelated keys.
   */
  private async _getAllKeys(): Promise<string[]> {
    const pattern = this.prefix.replace(/[*?[\]\\]/g, "\\$&") + "*";
    const keys: string[] = [];
    for await (const key of this.client.scanIterator(pattern, this.scanCount)) {
      keys.push(key);
    }
    return keys;
  }
}

export default RedisStore;