UNPKG

@cap-js-community/websocket

Version:
332 lines (314 loc) 10.8 kB
"use strict";

const cds = require("@sap/cds");
const { Server } = require("socket.io");

const SocketServer = require("./base");
const redis = require("../redis");

const LOG = cds.log("websocket/socket.io");
const DEBUG = cds.debug("websocket");

// Broadcast filter dimensions, in the order their rooms are addressed.
const FILTER_KEYS = ["user", "role", "context", "identifier"];

/**
 * WebSocket server implementation backed by socket.io.
 *
 * Sockets are joined to rooms named after combinations of tenant, user, role, context and
 * identifier (see room()), so that broadcasts can be narrowed with include/exclude filters.
 */
class SocketIOServer extends SocketServer {
  /**
   * @param {object} server - server instance socket.io's `Server` is attached to
   * @param {string} path - forwarded to SocketServer
   * @param {object} [config] - websocket configuration; `config.options` is passed to socket.io's `Server`
   */
  constructor(server, path, config) {
    super(server, path, config);
    this.io = new Server(server, config?.options);
    this.io.engine.on("connection_error", (err) => {
      // Strip the request object before logging (presumably to keep the log entry small).
      delete err.req;
      LOG?.error(err);
    });
    cds.ws = this;
    cds.io = this.io;
  }

  /** Finishes asynchronous initialization by applying the configured adapter. */
  async setup() {
    await this.applyAdapter();
  }

  /**
   * Serves a CDS service on its own socket.io namespace.
   *
   * For every connection a `socket.facade` is built that exposes on/emit/broadcast and
   * context-room helpers; `connected` (if given) is then called with that facade.
   *
   * @param {object} service - CDS service definition
   * @param {string} path - service path, resolved via servicePath()
   * @param {Function} [connected] - called with the facade once the socket is set up
   */
  service(service, path, connected) {
    const io = this.applyMiddlewares(this.io.of(this.servicePath(path)));
    const format = this.format(service, undefined, "json");
    io.on("connection", async (socket) => {
      try {
        this.onInit(socket, socket.request);
        socket.context = this.initContext();
        // Fall back to the `id` query parameter as the client identifier.
        socket.request.id ??= socket.request._query?.id;
        socket.on("disconnect", () => {
          this.onDisconnect(socket);
          DEBUG?.("Disconnected", socket.id);
        });
        socket.facade = {
          service,
          path,
          socket,
          get context() {
            return socket.context;
          },
          on: (event, callback) => {
            socket.on(event, async (data, headers, fn) => {
              // Headers are optional: without them, the second argument is the acknowledgement callback.
              if (typeof headers === "function") {
                fn = headers;
                headers = undefined;
              }
              try {
                await callback(format.parse(data).data, headers, fn);
              } catch (err) {
                // socket.io does not await event listeners, so rethrowing here would only
                // produce an unhandled promise rejection (fatal by default on Node >= 15).
                LOG?.error(err);
              }
            });
          },
          emit: async (event, data, headers) => {
            try {
              await socket.emit(event, format.compose(event, data, headers));
            } catch (err) {
              LOG?.error(err);
              throw err;
            }
          },
          broadcast: async (event, data, headers, filter) => {
            await this.broadcast({
              tenant: socket.context.tenant,
              service,
              path,
              event,
              data,
              headers,
              filter,
              socket,
            });
          },
          broadcastAll: async (event, data, headers, filter) => {
            await this.broadcast({
              tenant: socket.context.tenant,
              service,
              path,
              event,
              data,
              headers,
              filter,
              socket: null,
            });
          },
          enter: async (context) => {
            socket.contexts.add(context);
            for (const entry of this.combineValues(socket, { context: [context] })) {
              // join() may return a promise with asynchronous adapters; await it so that
              // callers awaiting enter() can rely on the membership being established.
              await socket.join(room(entry));
            }
          },
          exit: async (context) => {
            socket.contexts.delete(context);
            for (const entry of this.combineValues(socket, { context: [context] })) {
              await socket.leave(room(entry));
            }
          },
          reset: async () => {
            // exit() deletes from the set being iterated; deletion during Set iteration is well-defined.
            for (const context of socket.contexts) {
              await socket.facade.exit(context);
            }
          },
          disconnect() {
            socket.disconnect();
          },
          onDisconnect: (callback) => {
            socket.on("disconnect", callback);
          },
        };
        socket.context.ws = { service: socket.facade, socket, io };
        this.onConnect(socket);
        if (connected) {
          connected(socket.facade);
        }
        DEBUG?.("Connected", socket.id);
      } catch (err) {
        LOG?.error(err);
      }
    });
  }

  /**
   * Emits an event to the sockets of a service namespace, optionally narrowed by a filter.
   *
   * Without include values the event goes to the tenant room. Include/exclude values for
   * user, role, context and identifier are combined with the operator returned by
   * serviceOperator(): "and" addresses rooms matching all listed dimensions at once, any
   * other operator ("or", default) addresses a room per individual value.
   *
   * @param {object} options
   * @param {string} [options.tenant] - defaults to the tenant of `socket`
   * @param {object} options.service - CDS service definition
   * @param {string} [options.path] - defaults to defaultPath(service)
   * @param {string} options.event - event name
   * @param {*} options.data - payload, composed with the service/event format
   * @param {object} [options.headers]
   * @param {object} [options.filter] - `{ user, role, context, identifier }`, each `{ include?, exclude? }`
   * @param {object} [options.socket] - originating socket; when given, emission goes through `socket.broadcast`
   * @throws rethrows (after logging) any error raised while addressing or emitting
   */
  async broadcast({ tenant, service, path, event, data, headers, filter, socket }) {
    try {
      const targetTenant = tenant || socket?.context.tenant;
      const targetPath = path || this.defaultPath(service);
      // socket.broadcast addresses every socket of the namespace except the sender.
      let to = socket?.broadcast || this.io.of(this.servicePath(targetPath));
      const selection = filter ?? {};

      const included = filterRooms(this, service, event, targetTenant, selection, "include");
      if (included) {
        for (const name of included) {
          to = to.to(name);
        }
      } else {
        to = to.to(room({ tenant: targetTenant }));
      }

      const excluded = filterRooms(this, service, event, targetTenant, selection, "exclude");
      for (const name of excluded ?? []) {
        to = to.except(name);
      }

      const format = this.format(service, event, "json");
      to.emit(event, format.compose(event, data, headers));
    } catch (err) {
      LOG?.error(err);
      throw err;
    }
  }

  /**
   * Disconnects a single socket, or closes the whole socket.io server when none is given.
   *
   * @param {object} [socket]
   */
  close(socket) {
    if (socket) {
      socket.disconnect(true);
    } else {
      this.io.close();
    }
  }

  /** Extends base initialization with the set of contexts the socket has entered. */
  onInit(socket, request) {
    super.onInit(socket, request);
    socket.contexts = new Set(); // Set<context>
  }

  /** Joins the socket to every room derived from its tenant, user, roles and identifier. */
  onConnect(socket) {
    for (const combination of this.combineValues(socket)) {
      socket.join(room(combination));
    }
  }

  /** Forgets the contexts the socket had entered. */
  onDisconnect(socket) {
    socket.contexts.clear();
  }

  /**
   * Builds room descriptors as the cartesian product tenant × user × role × context × identifier.
   *
   * Each dimension uses the `filter` values when given. Otherwise it uses the socket-derived
   * values: tenant, user id, the service roles the user holds (`user.is(role)`) and the request
   * identifier. Every dimension except tenant also gets an `undefined` entry so that coarser
   * rooms are produced too. Duplicates within a dimension are removed.
   *
   * @param {object} [socket]
   * @param {object} [filter] - optional `{ tenant, user, role, context, identifier }` value arrays
   * @returns {Array<{tenant, user, role, context, identifier}>}
   */
  combineValues(socket, filter) {
    const values = {
      tenant: [],
      user: [],
      roles: [],
      contexts: [],
      identifier: [],
    };
    if (socket) {
      values.tenant = [socket.context.tenant];
      values.user = [socket.context.user?.id];
      const roles = this.serviceRoles(socket.facade.service);
      if (roles?.length && socket.context.user?.is) {
        for (const role of roles) {
          if (socket.context.user.is(role)) {
            values.roles.push(role);
          }
        }
      }
      values.identifier = [socket.request.id];
    }
    const combinations = [];
    for (const tenant of unique(filter?.tenant ?? values.tenant)) {
      for (const user of unique(filter?.user ?? [undefined, ...values.user])) {
        for (const role of unique(filter?.role ?? [undefined, ...values.roles])) {
          for (const context of unique(filter?.context ?? [undefined, ...values.contexts])) {
            for (const identifier of unique(filter?.identifier ?? [undefined, ...values.identifier])) {
              combinations.push({ tenant, user, role, context, identifier });
            }
          }
        }
      }
    }
    return combinations;
  }

  /**
   * Creates and installs the socket.io adapter named by `config.adapter.impl`, if any.
   * Failures are logged and leave the default in-memory adapter in place.
   */
  async applyAdapter() {
    try {
      const config = { ...this.config?.adapter };
      if (!config.impl) {
        return;
      }
      const options = { ...config.options };
      const adapterFactory = SocketServer.require(config.impl, "adapter");
      if (!adapterFactory) {
        return;
      }
      switch (config.impl) {
        case "@socket.io/redis-adapter":
          if (await redis.connectionCheck(config)) {
            const [pubClient, subClient] = await Promise.all([
              redis.createPrimaryClientAndConnect(config),
              redis.createSecondaryClientAndConnect(config),
            ]);
            // Promise.all always yields two entries; each client may still be missing.
            if (pubClient && subClient) {
              this.adapter = adapterFactory.createAdapter(pubClient, subClient, options);
            }
          }
          break;
        case "@socket.io/redis-streams-adapter":
          if (await redis.connectionCheck(config)) {
            const client = await redis.createPrimaryClientAndConnect(config);
            if (client) {
              this.adapter = adapterFactory.createAdapter(client, options);
            }
          }
          break;
        default:
          this.adapter = adapterFactory.createAdapter(this, options, config);
          break;
      }
      if (this.adapter) {
        this.io.adapter(this.adapter);
        this.adapterImpl = config.impl;
        this.adapterActive = true;
      }
    } catch (err) {
      LOG?.error(err);
    }
  }

  /** Registers every configured middleware on the given namespace and returns it. */
  applyMiddlewares(io) {
    for (const middleware of this.middlewares()) {
      io.use(middleware);
    }
    return io;
  }
}

/**
 * Resolves the room names addressed by one side ("include" or "exclude") of a broadcast filter.
 *
 * @param {SocketIOServer} server - provides serviceOperator() and combineValues()
 * @param {object} service - CDS service definition
 * @param {string} event - event name
 * @param {string} tenant - tenant the rooms belong to
 * @param {object} filter - `{ user, role, context, identifier }`, each optionally `{ include, exclude }`
 * @param {"include"|"exclude"} side - which half of the filter to resolve
 * @returns {string[]|null} room names, or null when no dimension lists values for `side`
 */
function filterRooms(server, service, event, tenant, filter, side) {
  const valuesOf = (key) => filter[key]?.[side];
  if (!FILTER_KEYS.some((key) => valuesOf(key)?.length)) {
    return null;
  }
  if (server.serviceOperator(service, event, side) === "and") {
    // Intersection: one room per combination carrying a value for every listed dimension.
    const selected = { tenant: [tenant] };
    for (const key of FILTER_KEYS) {
      const values = valuesOf(key);
      selected[key] = values?.length > 0 ? values : undefined;
    }
    return server.combineValues(undefined, selected).map((entry) => room(entry));
  }
  // Union ("or" and any other operator): one room per individual value.
  const rooms = [];
  for (const key of FILTER_KEYS) {
    for (const value of valuesOf(key) || []) {
      rooms.push(room({ tenant, [key]: value }));
    }
  }
  return rooms;
}

/** Returns the distinct elements of `array`, keeping first-occurrence order. */
function unique(array) {
  return [...new Set(array)];
}

/**
 * Builds a room name from the truthy parts of a descriptor, e.g. `/tenant:t1#/user:alice#`.
 * Segments always appear in the order tenant, user, role, context, identifier.
 */
function room({ tenant, user, role, context, identifier }) {
  const segments = { tenant, user, role, context, identifier };
  return Object.entries(segments)
    .filter(([, value]) => value)
    .map(([key, value]) => `/${key}:${value}#`)
    .join("");
}

module.exports = SocketIOServer;