UNPKG

@directus/api

Version:

Directus is a real-time API and App dashboard for managing SQL database content

353 lines (352 loc) 14.9 kB
import { useEnv } from '@directus/env'; import { InvalidProviderConfigError, TokenExpiredError } from '@directus/errors'; import { parseJSON, toBoolean } from '@directus/utils'; import cookie from 'cookie'; import { randomUUID } from 'node:crypto'; import { parse } from 'url'; import WebSocket, { WebSocketServer } from 'ws'; import { fromZodError } from 'zod-validation-error'; import emitter from '../../emitter.js'; import { useLogger } from '../../logger/index.js'; import { createDefaultAccountability } from '../../permissions/utils/create-default-accountability.js'; import { createRateLimiter } from '../../rate-limiter.js'; import { getIPFromReq } from '../../utils/get-ip-from-req.js'; import { authenticateConnection, authenticationSuccess } from '../authenticate.js'; import { WebSocketError, handleWebSocketError } from '../errors.js'; import { AuthMode, WebSocketAuthMessage, WebSocketMessage } from '../messages.js'; import { getMessageType } from '../utils/message.js'; import { waitForAnyMessage, waitForMessageType } from '../utils/wait-for-message.js'; const TOKEN_CHECK_INTERVAL = 15 * 60 * 1000; // 15 minutes const logger = useLogger(); export default class SocketController { server; clients; authentication; endpoint; maxConnections; rateLimiter; authInterval; constructor(httpServer, configPrefix) { this.server = new WebSocketServer({ noServer: true, // @ts-ignore TODO Remove once @types/ws has been updated autoPong: false, }); this.clients = new Set(); this.authInterval = null; const { endpoint, authentication, maxConnections } = this.getEnvironmentConfig(configPrefix); this.endpoint = endpoint; this.authentication = authentication; this.maxConnections = maxConnections; this.rateLimiter = this.getRateLimiter(); httpServer.on('upgrade', this.handleUpgrade.bind(this)); this.checkClientTokens(); } getEnvironmentConfig(configPrefix) { const env = useEnv(); const endpoint = String(env[`${configPrefix}_PATH`]); const authMode = AuthMode.safeParse(String(env[`${configPrefix}_AUTH`]).toLowerCase()); const authTimeout = Number(env[`${configPrefix}_AUTH_TIMEOUT`]) * 1000; const maxConnections = `${configPrefix}_CONN_LIMIT` in env ? Number(env[`${configPrefix}_CONN_LIMIT`]) : Number.POSITIVE_INFINITY; if (!authMode.success) { throw new InvalidProviderConfigError({ provider: 'ws', reason: fromZodError(authMode.error, { prefix: `${configPrefix}_AUTH` }).message, }); } return { endpoint, maxConnections, authentication: { mode: authMode.data, timeout: authTimeout, }, }; } getRateLimiter() { const env = useEnv(); if (toBoolean(env['RATE_LIMITER_ENABLED']) === true) { return createRateLimiter('RATE_LIMITER', { keyPrefix: 'websocket', }); } return null; } catchInvalidMessages(ws) { /** * This fix was done to prevent the API from crashing on receiving invalid WebSocket frames * https://github.com/directus/directus/security/advisories/GHSA-hmgw-9jrg-hf2m * https://github.com/websockets/ws/issues/2098 */ // @ts-ignore <- required because "_socket" is not typed on WS ws._socket.prependListener('data', (data) => data.toString()); ws.on('error', (error) => { if (error.message) logger.debug(error.message); }); } async handleUpgrade(request, socket, head) { const { pathname, query } = parse(request.url, true); if (pathname !== this.endpoint) return; if (this.clients.size >= this.maxConnections) { logger.debug('WebSocket upgrade denied - max connections reached'); socket.write('HTTP/1.1 403 Forbidden\r\n\r\n'); socket.destroy(); return; } const env = useEnv(); const cookies = request.headers.cookie ? cookie.parse(request.headers.cookie) : {}; const sessionCookieName = env['SESSION_COOKIE_NAME']; const accountabilityOverrides = { ip: getIPFromReq(request) ?? null, }; const userAgent = request.headers['user-agent']?.substring(0, 1024); if (userAgent) accountabilityOverrides.userAgent = userAgent; const origin = request.headers['origin']; if (origin) accountabilityOverrides.origin = origin; const context = { request, socket, head, accountabilityOverrides }; if (this.authentication.mode === 'strict' || query['access_token'] || cookies[sessionCookieName]) { let token = null; if (typeof query['access_token'] === 'string') { token = query['access_token']; } else if (typeof cookies[sessionCookieName] === 'string') { token = cookies[sessionCookieName] ?? null; } await this.handleTokenUpgrade(context, token); return; } if (this.authentication.mode === 'handshake') { await this.handleHandshakeUpgrade(context); return; } this.server.handleUpgrade(request, socket, head, async (ws) => { this.catchInvalidMessages(ws); const state = { accountability: createDefaultAccountability(accountabilityOverrides), expires_at: null, }; this.server.emit('connection', ws, state); }); } async handleTokenUpgrade({ request, socket, head, accountabilityOverrides }, token) { let accountability = null; let expires_at = null; if (token) { try { const state = await authenticateConnection({ access_token: token }, accountabilityOverrides); accountability = state.accountability; expires_at = state.expires_at; } catch { accountability = null; expires_at = null; } } if (!token || !accountability || !accountability.user) { logger.debug('WebSocket upgrade denied - ' + JSON.stringify(accountability || 'invalid')); socket.write('HTTP/1.1 401 Unauthorized\r\n\r\n'); socket.destroy(); return; } try { this.checkUserRequirements(accountability); } catch { logger.debug('WebSocket upgrade denied - ' + JSON.stringify(accountability || 'invalid')); socket.write('HTTP/1.1 401 Unauthorized\r\n\r\n'); socket.destroy(); return; } this.server.handleUpgrade(request, socket, head, async (ws) => { this.catchInvalidMessages(ws); const state = { accountability, expires_at }; this.server.emit('connection', ws, state); }); } async handleHandshakeUpgrade({ request, socket, head, accountabilityOverrides }) { this.server.handleUpgrade(request, socket, head, async (ws) => { this.catchInvalidMessages(ws); try { const payload = await waitForAnyMessage(ws, this.authentication.timeout); if (getMessageType(payload) !== 'auth') throw new Error(); const state = await authenticateConnection(WebSocketAuthMessage.parse(payload), accountabilityOverrides); this.checkUserRequirements(state.accountability); ws.send(authenticationSuccess(payload['uid'], state.refresh_token)); this.server.emit('connection', ws, state); } catch { logger.debug('WebSocket authentication handshake failed'); const error = new WebSocketError('auth', 'AUTH_FAILED', 'Authentication handshake failed.'); handleWebSocketError(ws, error, 'auth'); ws.close(); } }); } createClient(ws, { accountability, expires_at }) { const client = ws; client.accountability = accountability; client.expires_at = expires_at; client.uid = randomUUID(); client.auth_timer = null; ws.on('message', async (data) => { if (this.rateLimiter !== null) { try { await this.rateLimiter.consume(client.uid); } catch (limit) { const timeout = limit?.msBeforeNext ?? this.rateLimiter.msDuration; const error = new WebSocketError('server', 'REQUESTS_EXCEEDED', `Too many messages, retry after ${timeout}ms.`); handleWebSocketError(client, error, 'server'); logger.debug(`WebSocket#${client.uid} is rate limited`); return; } } let message; try { message = this.parseMessage(data.toString()); } catch (err) { handleWebSocketError(client, err, 'server'); return; } if (getMessageType(message) === 'auth') { try { await this.handleAuthRequest(client, WebSocketAuthMessage.parse(message)); } catch { // ignore errors } return; } // this log cannot be higher in the function or it will leak credentials logger.trace(`WebSocket#${client.uid} - ${JSON.stringify(message)}`); ws.emit('parsed-message', message); }); ws.on('error', () => { logger.debug(`WebSocket#${client.uid} connection errored`); if (client.auth_timer) { clearTimeout(client.auth_timer); client.auth_timer = null; } this.clients.delete(client); }); ws.on('close', () => { logger.debug(`WebSocket#${client.uid} connection closed`); if (client.auth_timer) { clearTimeout(client.auth_timer); client.auth_timer = null; } this.clients.delete(client); }); logger.debug(`WebSocket#${client.uid} connected`); if (accountability) { logger.trace(`WebSocket#${client.uid} authenticated as ${JSON.stringify(accountability)}`); } this.setTokenExpireTimer(client); this.clients.add(client); return client; } parseMessage(data) { let message; try { message = WebSocketMessage.parse(parseJSON(data)); } catch { throw new WebSocketError('server', 'INVALID_PAYLOAD', 'Unable to parse the incoming message.'); } return message; } async handleAuthRequest(client, message) { try { let accountabilityOverrides = {}; /** * Re-use the existing ip, userAgent and origin accountability properties. * They are only sent in the original connection request */ if (client.accountability) { accountabilityOverrides = { ip: client.accountability.ip, userAgent: client.accountability.userAgent, origin: client.accountability.origin, }; } const { accountability, expires_at, refresh_token } = await authenticateConnection(message, accountabilityOverrides); this.checkUserRequirements(accountability); client.accountability = accountability; client.expires_at = expires_at; this.setTokenExpireTimer(client); emitter.emitAction('websocket.auth.success', { client }); client.send(authenticationSuccess(message.uid, refresh_token)); logger.trace(`WebSocket#${client.uid} authenticated as ${JSON.stringify(client.accountability)}`); } catch (error) { logger.trace(`WebSocket#${client.uid} failed authentication`); emitter.emitAction('websocket.auth.failure', { client }); client.accountability = null; client.expires_at = null; const _error = error instanceof WebSocketError ? error : new WebSocketError('auth', 'AUTH_FAILED', 'Authentication failed.', message.uid); handleWebSocketError(client, _error, 'auth'); if (this.authentication.mode !== 'public') { client.close(); } } } checkUserRequirements(_accountability) { // there are no requirements in the abstract class return; } setTokenExpireTimer(client) { if (client.auth_timer !== null) { // clear up old timeouts if needed clearTimeout(client.auth_timer); client.auth_timer = null; } if (!client.expires_at) return; const expiresIn = client.expires_at * 1000 - Date.now(); if (expiresIn > TOKEN_CHECK_INTERVAL) return; client.auth_timer = setTimeout(() => { client.accountability = null; client.expires_at = null; handleWebSocketError(client, new TokenExpiredError(), 'auth'); waitForMessageType(client, 'auth', this.authentication.timeout).catch((msg) => { const error = new WebSocketError('auth', 'AUTH_TIMEOUT', 'Authentication timed out.', msg?.uid); handleWebSocketError(client, error, 'auth'); if (this.authentication.mode !== 'public') { client.close(); } }); }, expiresIn); } checkClientTokens() { this.authInterval = setInterval(() => { if (this.clients.size === 0) return; // check the clients and set shorter timeouts if needed for (const client of this.clients) { if (client.expires_at === null || client.auth_timer !== null) continue; this.setTokenExpireTimer(client); } }, TOKEN_CHECK_INTERVAL); } terminate() { if (this.authInterval) clearInterval(this.authInterval); this.clients.forEach((client) => { if (client.auth_timer) clearTimeout(client.auth_timer); }); this.server.clients.forEach((ws) => { ws.terminate(); }); } }