UNPKG

@directus/api

Version:

Directus is a real-time API and App dashboard for managing SQL database content

106 lines (105 loc) 5.15 kB
import { CloseCode, MessageType, makeServer } from 'graphql-ws';
import { useLogger } from '../../logger/index.js';
import { createDefaultAccountability } from '../../permissions/utils/create-default-accountability.js';
import { bindPubSub } from '../../services/graphql/subscription.js';
import { GraphQLService } from '../../services/index.js';
import { getAddress } from '../../utils/get-address.js';
import { getSchema } from '../../utils/get-schema.js';
import { authenticateConnection } from '../authenticate.js';
import { handleWebSocketError } from '../errors.js';
import { ConnectionParams, WebSocketMessage } from '../messages.js';
import { getMessageType } from '../utils/message.js';
import SocketController from './base.js';
import { registerWebSocketEvents } from './hooks.js';

const logger = useLogger();

/**
 * Schema factory handed to graphql-ws: builds a GraphQL schema for the accountability
 * attached to the connecting client (`ctx.extra.client`, supplied in `bindEvents`).
 *
 * For now only the items scope is exposed; system events are still to be decided.
 *
 * @param {object} ctx - graphql-ws connection context.
 * @returns {Promise<unknown>} Whatever `GraphQLService#getSchema()` returns.
 */
async function buildSubscriptionSchema({ extra }) {
	const { accountability } = extra.client;
	const schema = await getSchema();

	const service = new GraphQLService({
		schema,
		scope: 'items',
		accountability,
	});

	return service.getSchema();
}

/**
 * WebSocket controller that serves GraphQL subscriptions through a graphql-ws server
 * attached to the socket server managed by `SocketController`.
 */
export class GraphQLSubscriptionController extends SocketController {
	gql;

	/**
	 * @param {object} httpServer - HTTP server the websocket endpoint is attached to.
	 */
	constructor(httpServer) {
		// NOTE(review): 'WEBSOCKETS_GRAPHQL' is presumably the config/env prefix read by
		// SocketController — confirm in base.js.
		super(httpServer, 'WEBSOCKETS_GRAPHQL');
		registerWebSocketEvents();

		this.server.on('connection', (ws, auth) => {
			const client = this.createClient(ws, auth);
			this.bindEvents(client);
		});

		this.gql = makeServer({ schema: buildSubscriptionSchema });

		bindPubSub();

		const address = getAddress(httpServer);
		logger.info(`GraphQL Subscriptions started at ${address}${this.endpoint}`);
	}

	/**
	 * Connects one client socket to the graphql-ws server.
	 *
	 * The client is adapted to graphql-ws's socket interface, and parsed messages are routed
	 * through the authentication guard. Socket closure is reported back to graphql-ws.
	 * In strict mode, a client without an authenticated user is closed immediately.
	 *
	 * @param {object} client - Client created by `createClient`.
	 */
	bindEvents(client) {
		const socket = {
			protocol: client.protocol,
			// graphql-ws expects a promise-returning send, so wrap the callback-style API
			send: (data) =>
				new Promise((resolve, reject) => {
					client.send(data, (err) => {
						if (err) {
							reject(err);
						} else {
							resolve();
						}
					});
				}),
			// for standard closures
			close: (code, reason) => client.close(code, reason),
			onMessage: (cb) => {
				client.on('parsed-message', (message) => this.#relayMessage(client, message, cb));
			},
		};

		const closedHandler = this.gql.opened(socket, { client });

		// notify server that the socket closed
		client.once('close', (code, reason) => closedHandler(code, reason.toString()));

		// check strict authentication status
		const isAnonymousInStrictMode = this.authentication.mode === 'strict' && !client.accountability?.user;

		if (isAnonymousInStrictMode) {
			client.close(CloseCode.Forbidden, 'Forbidden');
		}
	}

	/**
	 * Applies the authentication rules to one parsed client message. If the message is
	 * allowed through, it is re-serialised as JSON and handed to graphql-ws via `cb`.
	 *
	 * - `connection_init` when the mode is not strict: the payload is validated with
	 *   `ConnectionParams`. In handshake mode, a string `access_token` is required.
	 *   The token is exchanged for accountability and expiry via `authenticateConnection`;
	 *   without a token, the socket is closed with Forbidden.
	 * - Any other message in handshake mode: the socket is closed with Forbidden while the
	 *   client has no authenticated user.
	 *
	 * Any error thrown here, including from `cb`, is reported through `handleWebSocketError`.
	 *
	 * @param {object} client - Client the message came from.
	 * @param {object} message - Parsed incoming message.
	 * @param {(data: string) => Promise<void>} cb - graphql-ws message handler.
	 * @returns {Promise<void>}
	 */
	async #relayMessage(client, message, cb) {
		try {
			const isInit = getMessageType(message) === 'connection_init';
			const { mode } = this.authentication;

			if (!isInit || mode === 'strict') {
				// the first message should authenticate successfully in this mode
				if (mode === 'handshake' && !client.accountability?.user) {
					client.close(CloseCode.Forbidden, 'Forbidden');
					return;
				}
			} else {
				const params = ConnectionParams.parse(message['payload'] ?? {});

				if (mode === 'handshake') {
					if (typeof params.access_token !== 'string') {
						client.close(CloseCode.Forbidden, 'Forbidden');
						return;
					}

					const { accountability, expires_at: expiresAt } = await authenticateConnection(
						{ access_token: params.access_token },
						{ ip: client.accountability?.ip ?? null },
					);

					client.accountability = accountability;
					client.expires_at = expiresAt;
				}
			}

			await cb(JSON.stringify(message));
		} catch (error) {
			handleWebSocketError(client, error, MessageType.Error);
		}
	}

	/**
	 * Clears any pending auth timer on the client.
	 *
	 * In handshake mode, it then starts a new timer. The timer closes the socket with
	 * Forbidden if no user has authenticated after `this.authentication.timeout` milliseconds.
	 *
	 * @param {object} client - Client whose `auth_timer` is managed.
	 */
	setTokenExpireTimer(client) {
		if (client.auth_timer !== null) {
			clearTimeout(client.auth_timer);
			client.auth_timer = null;
		}

		if (this.authentication.mode !== 'handshake') return;

		const closeIfStillAnonymous = () => {
			if (client.accountability?.user) return;
			client.close(CloseCode.Forbidden, 'Forbidden');
		};

		client.auth_timer = setTimeout(closeIfStillAnonymous, this.authentication.timeout);
	}

	/**
	 * Completes the HTTP upgrade for handshake-mode connections.
	 *
	 * Emits 'connection' with the accountability from `createDefaultAccountability()` and
	 * no expiry. Authentication is expected to follow via the `connection_init` message.
	 *
	 * @param {{ request: object, socket: object, head: object }} upgrade - Upgrade request parts.
	 */
	async handleHandshakeUpgrade({ request, socket, head }) {
		this.server.handleUpgrade(request, socket, head, async (ws) => {
			const initialState = { accountability: createDefaultAccountability(), expires_at: null };
			this.server.emit('connection', ws, initialState);
			// actual enforcement is handled by the setTokenExpireTimer function
		});
	}
}