/*
 * @directus/api
 * Version: (not specified)
 * Directus is a real-time API and App dashboard for managing SQL database content.
 * 106 lines (105 loc) • 5.15 kB
 * JavaScript
 */
import { CloseCode, MessageType, makeServer } from 'graphql-ws';
import { useLogger } from '../../logger/index.js';
import { createDefaultAccountability } from '../../permissions/utils/create-default-accountability.js';
import { bindPubSub } from '../../services/graphql/subscription.js';
import { GraphQLService } from '../../services/index.js';
import { getAddress } from '../../utils/get-address.js';
import { getSchema } from '../../utils/get-schema.js';
import { authenticateConnection } from '../authenticate.js';
import { handleWebSocketError } from '../errors.js';
import { ConnectionParams, WebSocketMessage } from '../messages.js';
import { getMessageType } from '../utils/message.js';
import SocketController from './base.js';
import { registerWebSocketEvents } from './hooks.js';
// Module-level logger; the controller uses it to announce the subscription endpoint once started.
const logger = useLogger();
/**
 * Socket controller that exposes GraphQL subscriptions using the `graphql-ws` protocol.
 *
 * Every incoming connection is turned into a client via the base controller's `createClient`
 * and then attached to a `graphql-ws` server. The GraphQL schema is resolved per connection from
 * the connecting client's accountability and is limited to the `items` scope.
 */
export class GraphQLSubscriptionController extends SocketController {
    /** The `graphql-ws` server instance created by `makeServer` in the constructor. */
    gql;

    /**
     * @param httpServer - HTTP server the WebSocket endpoint is attached to; also used to log the address.
     */
    constructor(httpServer) {
        // NOTE(review): 'WEBSOCKETS_GRAPHQL' presumably selects this controller's config namespace in the base class — confirm.
        super(httpServer, 'WEBSOCKETS_GRAPHQL');
        registerWebSocketEvents();

        this.server.on('connection', (ws, auth) => {
            const client = this.createClient(ws, auth);
            this.bindEvents(client);
        });

        // Built per connection so the schema reflects the connecting client's accountability.
        // Only collection items are exposed for now; system events are still to be decided.
        const resolveSchema = async (ctx) => {
            const { accountability } = ctx.extra.client;
            const schema = await getSchema();
            const graphql = new GraphQLService({ schema, scope: 'items', accountability });
            return graphql.getSchema();
        };
        this.gql = makeServer({ schema: resolveSchema });

        bindPubSub();
        logger.info(`GraphQL Subscriptions started at ${getAddress(httpServer)}${this.endpoint}`);
    }

    /**
     * Attaches a client socket to the graphql-ws server and enforces the authentication mode.
     *
     * - `handshake`: a `connection_init` payload must carry a string `access_token`. That token is
     *   authenticated and the resulting accountability/expiry is stored on the client. Any other
     *   message from a client without an authenticated user closes the socket as Forbidden.
     * - `strict`: the client must already have an authenticated user when it is bound. Otherwise
     *   the socket is closed as Forbidden.
     * - Any other mode forwards messages without further checks here.
     *
     * Errors raised while handling a message are reported through `handleWebSocketError`.
     *
     * @param client - The client created by `createClient` for the incoming socket.
     */
    bindEvents(client) {
        // Close the socket with graphql-ws's Forbidden close code.
        const forbid = () => client.close(CloseCode.Forbidden, 'Forbidden');

        // Applies the authentication rules to one parsed message, then hands it to graphql-ws.
        const relay = async (message, forward) => {
            try {
                const isInit = getMessageType(message) === 'connection_init';
                const { mode } = this.authentication;
                if (isInit && mode !== 'strict') {
                    const params = ConnectionParams.parse(message['payload'] ?? {});
                    if (mode === 'handshake') {
                        if (typeof params.access_token !== 'string') {
                            forbid();
                            return;
                        }
                        const { accountability: granted, expires_at: expiry } = await authenticateConnection(
                            { access_token: params.access_token },
                            { ip: client.accountability?.ip ?? null },
                        );
                        client.accountability = granted;
                        client.expires_at = expiry;
                    }
                } else if (mode === 'handshake' && !client.accountability?.user) {
                    // the first message should authenticate successfully in this mode
                    forbid();
                    return;
                }
                await forward(JSON.stringify(message));
            } catch (error) {
                handleWebSocketError(client, error, MessageType.Error);
            }
        };

        // Adapter between the client socket and the graphql-ws server.
        const socket = {
            protocol: client.protocol,
            send: (data) => new Promise((resolve, reject) => {
                client.send(data, (err) => (err ? reject(err) : resolve()));
            }),
            // Standard closures requested by graphql-ws itself.
            close: (code, reason) => client.close(code, reason),
            onMessage: (forward) => {
                client.on('parsed-message', (message) => relay(message, forward));
            },
        };
        const closedHandler = this.gql.opened(socket, { client });

        // Tell graphql-ws when the underlying socket goes away.
        client.once('close', (code, reason) => closedHandler(code, reason.toString()));

        // Strict mode requires an authenticated user at bind time.
        if (this.authentication.mode === 'strict' && !client.accountability?.user) {
            forbid();
        }
    }

    /**
     * (Re)schedules the handshake-mode authentication deadline for a client.
     *
     * Any previously scheduled timer is cleared first. Outside `handshake` mode nothing new is
     * scheduled. In `handshake` mode, the timer fires after `authentication.timeout`. At that point
     * the socket is closed as Forbidden if the client still has no authenticated user.
     *
     * @param client - Client whose `auth_timer` is replaced.
     */
    setTokenExpireTimer(client) {
        const previous = client.auth_timer;
        if (previous !== null) {
            clearTimeout(previous);
            client.auth_timer = null;
        }
        if (this.authentication.mode === 'handshake') {
            const { timeout } = this.authentication;
            client.auth_timer = setTimeout(() => {
                if (client.accountability?.user) return;
                client.close(CloseCode.Forbidden, 'Forbidden');
            }, timeout);
        }
    }

    /**
     * Completes the WebSocket upgrade without checking credentials. It then emits `connection`
     * with the result of `createDefaultAccountability()` and no expiry.
     *
     * Actual enforcement is handled by the setTokenExpireTimer function.
     */
    async handleHandshakeUpgrade({ request, socket, head }) {
        const onUpgrade = async (ws) => {
            const initialState = { accountability: createDefaultAccountability(), expires_at: null };
            this.server.emit('connection', ws, initialState);
        };
        this.server.handleUpgrade(request, socket, head, onUpgrade);
    }
}