/*
 * @cap-js-community/websocket
 * WebSocket adapter for CDS: socket.io server implementation.
 * (Source listing: 332 lines (314 loc), 10.8 kB, JavaScript.)
 */
;
const cds = require("@sap/cds");
const { Server } = require("socket.io");
const SocketServer = require("./base");
const redis = require("../redis");
const LOG = cds.log("websocket/socket.io");
const DEBUG = cds.debug("websocket");
class SocketIOServer extends SocketServer {
  /**
   * Create the socket.io server on top of the given server.
   *
   * Also publishes this instance as `cds.ws` and the raw socket.io server as `cds.io`.
   *
   * @param {object} server - server passed to the socket.io `Server` constructor
   * @param {string} path - base path, forwarded to the base class
   * @param {object} [config] - websocket configuration; `config.options` is handed to socket.io
   */
  constructor(server, path, config) {
    super(server, path, config);
    this.io = new Server(server, config?.options);
    this.io.engine.on("connection_error", (err) => {
      // Drop the raw request object so it is not dumped into the log.
      delete err.req;
      LOG?.error(err);
    });
    cds.ws = this;
    cds.io = this.io;
  }

  /**
   * Asynchronous setup step: attaches the configured socket.io adapter, if any.
   */
  async setup() {
    await this.applyAdapter();
  }

  /**
   * Register the socket.io namespace of a CDS service and wire up every connection.
   *
   * For each connection a `socket.facade` object is created. It exposes receiving,
   * emitting, broadcasting and context-room management for that socket.
   *
   * @param {object} service - CDS service served on this namespace
   * @param {string} path - service path, resolved via `servicePath`
   * @param {Function} [connected] - called with the facade once the socket is set up
   */
  service(service, path, connected) {
    const io = this.applyMiddlewares(this.io.of(this.servicePath(path)));
    const format = this.format(service, undefined, "json");
    io.on("connection", async (socket) => {
      try {
        this.onInit(socket, socket.request);
        socket.context = this.initContext();
        // Fall back to the `id` query parameter as the connection identifier.
        socket.request.id ??= socket.request._query?.id;
        socket.on("disconnect", () => {
          this.onDisconnect(socket);
          DEBUG?.("Disconnected", socket.id);
        });
        socket.facade = {
          service,
          path,
          socket,
          get context() {
            return socket.context;
          },
          on: (event, callback) => {
            socket.on(event, async (data, headers, fn) => {
              try {
                // Support both `(data, ack)` and `(data, headers, ack)` client call shapes.
                if (typeof headers === "function") {
                  fn = headers;
                  headers = undefined;
                }
                await callback(format.parse(data).data, headers, fn);
              } catch (err) {
                // The promise returned by this listener is not awaited by the emitter,
                // so rethrowing would only produce an unhandled rejection. Log instead.
                LOG?.error(err);
              }
            });
          },
          emit: async (event, data, headers) => {
            try {
              socket.emit(event, format.compose(event, data, headers));
            } catch (err) {
              LOG?.error(err);
              throw err;
            }
          },
          // Broadcast to the tenant's sockets, excluding this socket.
          broadcast: async (event, data, headers, filter) => {
            await this.broadcast({
              tenant: socket.context.tenant,
              service,
              path,
              event,
              data,
              headers,
              filter,
              socket,
            });
          },
          // Broadcast to the tenant's sockets, including this socket.
          broadcastAll: async (event, data, headers, filter) => {
            await this.broadcast({
              tenant: socket.context.tenant,
              service,
              path,
              event,
              data,
              headers,
              filter,
              socket: null,
            });
          },
          // Track the context and join every room combination containing it.
          enter: async (context) => {
            socket.contexts.add(context);
            for (const entry of this.combineValues(socket, { context: [context] })) {
              socket.join(room(entry));
            }
          },
          // Forget the context and leave every room combination containing it.
          exit: async (context) => {
            socket.contexts.delete(context);
            for (const entry of this.combineValues(socket, { context: [context] })) {
              socket.leave(room(entry));
            }
          },
          // Exit all entered contexts; iterate a snapshot because `exit` mutates the set.
          reset: async () => {
            for (const context of [...socket.contexts]) {
              await socket.facade.exit(context);
            }
          },
          disconnect() {
            socket.disconnect();
          },
          onDisconnect: (callback) => {
            socket.on("disconnect", callback);
          },
        };
        socket.context.ws = { service: socket.facade, socket: socket, io };
        this.onConnect(socket);
        // Await so that a rejection from an async `connected` callback is caught below.
        if (connected) {
          await connected(socket.facade);
        }
        DEBUG?.("Connected", socket.id);
      } catch (err) {
        LOG?.error(err);
      }
    });
  }

  /**
   * Broadcast an event to sockets of a service namespace, scoped by tenant and filters.
   *
   * Without include filters the event targets the tenant room. Include filters target
   * the matching rooms instead: one room per value for the "or" operator, or per value
   * combination for "and". Exclude filters remove rooms in the same way. When `socket`
   * is given, `socket.broadcast` is used as the starting point.
   *
   * @param {object} options
   * @param {string} [options.tenant] - tenant; defaults to the socket's context tenant
   * @param {object} options.service - CDS service
   * @param {string} [options.path] - service path; defaults to `defaultPath(service)`
   * @param {string} options.event - event name
   * @param {*} options.data - payload passed to the formatter
   * @param {object} [options.headers] - headers passed to the formatter
   * @param {object} [options.filter] - `{ user, role, context, identifier }`, each with
   *   optional `include` / `exclude` lists; `null` is treated like a missing filter
   * @param {object} [options.socket] - originating socket, or null/undefined
   * @throws rethrows any error after logging it
   */
  async broadcast({ tenant, service, path, event, data, headers, filter, socket }) {
    try {
      const { user, role, context, identifier } = filter ?? {};
      tenant = tenant || socket?.context.tenant;
      path = path || this.defaultPath(service);
      let to = socket?.broadcast || this.io.of(this.servicePath(path));

      // Room names selected by the `include` or `exclude` lists; null if none are set.
      const filterRooms = (kind) => {
        const lists = {
          user: user?.[kind],
          role: role?.[kind],
          context: context?.[kind],
          identifier: identifier?.[kind],
        };
        if (!Object.values(lists).some((list) => list?.length)) {
          return null;
        }
        if (this.serviceOperator(service, event, kind) === "and") {
          return this.combineValues(undefined, {
            tenant: [tenant],
            user: lists.user?.length > 0 ? lists.user : undefined,
            role: lists.role?.length > 0 ? lists.role : undefined,
            context: lists.context?.length > 0 ? lists.context : undefined,
            identifier: lists.identifier?.length > 0 ? lists.identifier : undefined,
          }).map((entry) => room(entry));
        }
        // "or" (and any unknown operator): one room per individual value.
        const rooms = [];
        for (const [key, list] of Object.entries(lists)) {
          for (const value of list || []) {
            rooms.push(room({ tenant, [key]: value }));
          }
        }
        return rooms;
      };

      const includeRooms = filterRooms("include");
      if (includeRooms) {
        for (const name of includeRooms) {
          to = to.to(name);
        }
      } else {
        to = to.to(room({ tenant }));
      }
      for (const name of filterRooms("exclude") ?? []) {
        to = to.except(name);
      }

      const format = this.format(service, event, "json");
      to.emit(event, format.compose(event, data, headers));
    } catch (err) {
      LOG?.error(err);
      throw err;
    }
  }

  /**
   * Close a single socket (forcefully) or, without argument, the whole socket.io server.
   *
   * @param {object} [socket] - socket to disconnect
   */
  close(socket) {
    if (socket) {
      socket.disconnect(true);
    } else {
      this.io.close();
    }
  }

  /**
   * Base-class initialization plus an empty set of entered contexts on the socket.
   */
  onInit(socket, request) {
    super.onInit(socket, request);
    socket.contexts = new Set(); // Set<context>
  }

  /**
   * Join every room combination derived from the socket's tenant, user, roles and identifier.
   */
  onConnect(socket) {
    const combinations = this.combineValues(socket);
    for (const combination of combinations) {
      socket.join(room(combination));
    }
  }

  /**
   * Forget all contexts the socket had entered.
   */
  onDisconnect(socket) {
    socket.contexts.clear();
  }

  /**
   * Build all `{ tenant, user, role, context, identifier }` combinations for room naming.
   *
   * Values come from `filter` where given. Otherwise they come from the socket: its
   * tenant, user id, the service roles the user has, and its request id. Each dimension
   * except tenant also includes `undefined`, so rooms for every subset of dimensions
   * are produced.
   *
   * @param {object} [socket] - socket whose context supplies default values
   * @param {object} [filter] - explicit value lists per dimension (singular keys)
   * @returns {object[]} list of combinations
   */
  combineValues(socket, filter) {
    const values = {
      tenant: [],
      user: [],
      roles: [],
      contexts: [],
      identifier: [],
    };
    if (socket) {
      values.tenant = [socket.context.tenant];
      values.user = [socket.context.user?.id];
      const roles = this.serviceRoles(socket.facade.service);
      if (roles?.length && socket.context.user?.is) {
        for (const role of roles) {
          if (socket.context.user.is(role)) {
            values.roles.push(role);
          }
        }
      }
      values.identifier = [socket.request.id];
    }
    const combinations = [];
    for (const tenant of unique(filter?.tenant ?? values.tenant)) {
      for (const user of unique(filter?.user ?? [undefined, ...values.user])) {
        for (const role of unique(filter?.role ?? [undefined, ...values.roles])) {
          for (const context of unique(filter?.context ?? [undefined, ...values.contexts])) {
            for (const identifier of unique(filter?.identifier ?? [undefined, ...values.identifier])) {
              combinations.push({ tenant, user, role, context, identifier });
            }
          }
        }
      }
    }
    return combinations;
  }

  /**
   * Create and attach the socket.io adapter configured under `config.adapter`.
   *
   * Errors are logged and not rethrown; on failure the server runs without an adapter.
   */
  async applyAdapter() {
    try {
      const config = { ...this.config?.adapter };
      if (config.impl) {
        const options = { ...config.options };
        const adapterFactory = SocketServer.require(config.impl, "adapter");
        if (adapterFactory) {
          switch (config.impl) {
            case "@socket.io/redis-adapter": {
              if (await redis.connectionCheck(config)) {
                const [primaryClient, secondaryClient] = await Promise.all([
                  redis.createPrimaryClientAndConnect(config),
                  redis.createSecondaryClientAndConnect(config),
                ]);
                // Promise.all always yields two entries, so check the clients themselves.
                if (primaryClient && secondaryClient) {
                  this.adapter = adapterFactory.createAdapter(primaryClient, secondaryClient, options);
                }
              }
              break;
            }
            case "@socket.io/redis-streams-adapter": {
              if (await redis.connectionCheck(config)) {
                const client = await redis.createPrimaryClientAndConnect(config);
                if (client) {
                  this.adapter = adapterFactory.createAdapter(client, options);
                }
              }
              break;
            }
            default:
              this.adapter = adapterFactory.createAdapter(this, options, config);
              break;
          }
          if (this.adapter) {
            this.io.adapter(this.adapter);
            this.adapterImpl = config.impl;
            this.adapterActive = true;
          }
        }
      }
    } catch (err) {
      LOG?.error(err);
    }
  }

  /**
   * Register all middlewares from `this.middlewares()` on a namespace.
   *
   * @param {object} io - socket.io namespace
   * @returns {object} the same namespace, for chaining
   */
  applyMiddlewares(io) {
    for (const middleware of this.middlewares()) {
      io.use(middleware);
    }
    return io;
  }
}
/**
 * Deduplicate an iterable, keeping first-occurrence order.
 * A nullish argument yields an empty array.
 */
function unique(array) {
  return Array.from(new Set(array));
}
/**
 * Build a socket.io room name from the truthy dimensions of a combination.
 * Each present dimension contributes a `/<key>:<value>#` segment, in a fixed order.
 */
function room({ tenant, user, role, context, identifier }) {
  const segments = [
    ["tenant", tenant],
    ["user", user],
    ["role", role],
    ["context", context],
    ["identifier", identifier],
  ];
  return segments
    .filter(([, value]) => value)
    .map(([key, value]) => `/${key}:${value}#`)
    .join("");
}
// The module's value is the socket.io-based server class.
module.exports = SocketIOServer;