UNPKG

fa-session-redis

Version:

Redis session store for farrow-auth-session with support for both redis and ioredis clients

331 lines 12.7 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.createNormalizedRedisClient = createNormalizedRedisClient; exports.createRedisSessionStore = createRedisSessionStore; const farrow_auth_session_1 = require("farrow-auth-session"); const ulid_1 = require("ulid"); function isIoRedisClient(client) { return typeof client.scan === 'function' && typeof client.mget === 'function' && !client.scanIterator; } function isNodeRedisClient(client) { return typeof client.scanIterator === 'function' && typeof client.mGet === 'function'; } function createNormalizedRedisClient(client) { if (isIoRedisClient(client)) { return { get: async (key) => { return client.get(key); }, set: async (key, value) => { const result = await client.set(key, value); return result === 'OK' || result === '1'; }, setex: async (key, seconds, value) => { if (typeof client.setex === 'function') { const result = await client.setex(key, seconds, value); return result === 'OK' || result === '1'; } const setResult = await client.set(key, value); if (setResult === 'OK' || setResult === '1') { const expireResult = await client.expire(key, seconds); return expireResult === 1 || expireResult === true; } return false; }, del: async (keyOrKeys) => { const keys = Array.isArray(keyOrKeys) ? keyOrKeys : [keyOrKeys]; if (typeof client.del === 'function') { const result = await client.del(...keys); return result; } // Fallback for clients without proper del let deleted = 0; for (const key of keys) { const result = await client.del(key); deleted += result; } return deleted; }, expire: async (key, seconds) => { const result = await client.expire(key, seconds); return result === 1 || result === true; }, ttl: async (key) => { if (typeof client.ttl === 'function') { return client.ttl(key); } return -2; }, mget: async (keys) => { if (typeof client.mget === 'function') { return client.mget(...keys); } const results = []; for (const key of keys) { results.push(await client.get(key)); } return results; }, scanIterator: async function* (match, count) { if (typeof client.scan === 'function') { let cursor = '0'; do { const [nextCursor, keys] = await client.scan(cursor, 'MATCH', match, 'COUNT', count); for (const key of keys) { yield key; } cursor = nextCursor; } while (cursor !== '0'); } else { return; } } }; } else if (isNodeRedisClient(client)) { return { get: async (key) => { return client.get(key); }, set: async (key, value) => { const result = await client.set(key, value); return result === 'OK'; }, setex: async (key, seconds, value) => { if (typeof client.setEx === 'function') { const result = await client.setEx(key, seconds, value); return result === 'OK'; } const setResult = await client.set(key, value); if (setResult === 'OK') { return client.expire(key, seconds); } return false; }, del: async (keyOrKeys) => { return client.del(keyOrKeys); }, expire: async (key, seconds) => { const result = await client.expire(key, seconds); return result === true || result === 1; }, ttl: async (key) => { if (typeof client.ttl === 'function') { return client.ttl(key); } return -2; }, mget: async (keys) => { if (typeof client.mGet === 'function') { return client.mGet(keys); } const results = []; for (const key of keys) { results.push(await client.get(key)); } return results; }, scanIterator: async function* (match, count) { if (typeof client.scanIterator === 'function') { const iterator = client.scanIterator({ MATCH: match, COUNT: count }); for await (const batch of iterator) { // node-redis 返回的可能是批次数组 if (Array.isArray(batch)) { for (const key of batch) { yield key; } } else { yield batch; } } } else { // Fallback for clients without scanIterator return; } } }; } else { throw new Error('Unsupported Redis client type. Please use redis or ioredis.'); } } // 类型守卫函数 function isNormalizedRedisClient(client) { return ('setex' in client && 'ttl' in client && 'mget' in client && typeof client.setex === 'function' && typeof client.ttl === 'function' && typeof client.mget === 'function'); } // 实现函数 function createRedisSessionStore(client, // 接受any类型,在内部进行类型检查和转换 options = {}) { // 验证客户端是否具备基本的Redis接口 if (!client || typeof client !== 'object') { throw new Error('Redis client is required'); } const requiredMethods = ['get', 'set', 'del', 'expire']; for (const method of requiredMethods) { if (typeof client[method] !== 'function') { throw new Error(`Redis client must implement ${method} method`); } } // 使用类型守卫检查客户端类型 const normalizedClient = isNormalizedRedisClient(client) ? client : createNormalizedRedisClient(client); const config = { prefix: 'session', ttl: 86400, rolling: false, renew: false, renewBefore: 600, genSessionId: () => (0, ulid_1.ulid)(), defaultData: () => ({}), ...options }; const getKey = (sessionId) => `${config.prefix}:${sessionId}`; const store = { async get(sessionId) { if (!sessionId) { return null; } const key = getKey(sessionId); const data = await normalizedClient.get(key); if (!data) { return null; } try { const userData = JSON.parse(data); // Update session metadata in context const expiresTime = config.ttl !== false ? Date.now() + (config.ttl * 1000) : Date.now() + (365 * 24 * 60 * 60 * 1000); // 1 year if no TTL farrow_auth_session_1.sessionMetaDataCtx.set({ sessionId, expiresTime }); // Handle rolling and renew strategies if (config.rolling && config.ttl !== false) { await normalizedClient.expire(key, config.ttl); } else if (config.renew && config.ttl !== false) { const ttl = await normalizedClient.ttl(key); if (ttl > 0 && ttl < config.renewBefore) { await normalizedClient.expire(key, config.ttl); } } return userData; } catch (error) { console.error('Failed to parse session data:', error); return undefined; } }, async set(userData) { const sessionMeta = farrow_auth_session_1.sessionMetaDataCtx.get(); if (!sessionMeta?.sessionId) { return undefined; } const key = getKey(sessionMeta.sessionId); try { const data = JSON.stringify(userData); if (config.ttl !== false) { const result = await normalizedClient.setex(key, config.ttl, data); return result ? true : false; } else { const result = await normalizedClient.set(key, data); return result ? true : false; } } catch (error) { console.error('Failed to save session:', error); return undefined; } }, async create(userData) { const sessionId = config.genSessionId(); const data = userData || config.defaultData(); const key = getKey(sessionId); try { const jsonData = JSON.stringify(data); let result; if (config.ttl !== false) { result = await normalizedClient.setex(key, config.ttl, jsonData); } else { result = await normalizedClient.set(key, jsonData); } if (result) { // Set session metadata in context for parser to use const expiresTime = config.ttl !== false ? Date.now() + (config.ttl * 1000) : Date.now() + (365 * 24 * 60 * 60 * 1000); // 1 year if no TTL farrow_auth_session_1.sessionMetaDataCtx.set({ sessionId, expiresTime }); return data; } return undefined; } catch (error) { console.error('Failed to create session:', error); return undefined; } }, async destroy() { const sessionMeta = farrow_auth_session_1.sessionMetaDataCtx.get(); if (!sessionMeta?.sessionId) { return false; } const key = getKey(sessionMeta.sessionId); try { const result = await normalizedClient.del(key); // Clear session metadata regardless of result farrow_auth_session_1.sessionMetaDataCtx.set(undefined); return result > 0; } catch (error) { console.error('Failed to destroy session:', error); return undefined; } }, async touch() { const sessionMeta = farrow_auth_session_1.sessionMetaDataCtx.get(); if (!sessionMeta?.sessionId) { return false; } if (config.ttl === false) { return false; // No TTL to update } const key = getKey(sessionMeta.sessionId); try { const result = await normalizedClient.expire(key, config.ttl); if (result) { // Update expiry time in context farrow_auth_session_1.sessionMetaDataCtx.set({ ...sessionMeta, expiresTime: Date.now() + (config.ttl * 1000) }); return true; } return false; } catch (error) { console.error('Failed to touch session:', error); return undefined; } } }; return store; } //# sourceMappingURL=index.js.map