/*
 * @cap-js-community/websocket
 * Version:
 * WebSocket adapter for CDS
 * 833 lines (784 loc) • 26.3 kB
 * JavaScript
 */
;
const cds = require("@sap/cds");
const SocketServer = require("./socket/base");
const redis = require("./redis");
// Logger for this module (cds.log id "/websocket").
const LOG = cds.log("/websocket");
// Milliseconds shutdownWebSocketServer waits for redis.closeClients() before resolving anyway.
const TIMEOUT_SHUTDOWN = 2500;
// Reserved operation names with special handling:
// - wsConnect runs after a socket connected (emitConnect),
// - wsDisconnect runs when a socket disconnects (bindServiceDefaults),
// - wsContext runs after the socket entered/exited the requested contexts (bindServiceDefaults).
// They are never registered as plain socket operations (bindServiceOperations skips them).
const WebSocketAction = {
  Connect: "wsConnect",
  Disconnect: "wsDisconnect",
  Context: "wsContext",
};
// Active socket server instance; assigned in bootstrapWebSocketServer (undefined when disabled or setup failed).
let socketServer;
// Services collected by collectServicesAndMountAdapter, keyed by service name; undefined until its first call.
let services;
/**
 * Adapter entry point: records `srv` in the module-level service map. The first call also
 * registers the CDS lifecycle hooks:
 * - "served": hands the collected services to the WebSocket server bootstrap,
 * - "shutdown": closes the Redis clients.
 * Only the `options` object of the first call is seen by these hooks.
 *
 * @param {object} srv - service to collect (keyed by `srv.name`)
 * @param {object} options - adapter options; `options.services` is set once CDS emits "served"
 */
const collectServicesAndMountAdapter = (srv, options) => {
  const isFirstService = !services;
  if (isFirstService) {
    services = {};
    const onServed = () => {
      options.services = services;
      serveWebSocketServer(options);
    };
    const onShutdown = async () => {
      await shutdownWebSocketServer();
    };
    cds.on("served", onServed);
    cds.on("shutdown", onShutdown);
  }
  services[srv.name] = srv;
};
/**
 * Defers the WebSocket bootstrap until CDS emits "listening", i.e. until the HTTP server exists.
 *
 * @param {object} options - adapter options forwarded to bootstrapWebSocketServer
 */
function serveWebSocketServer(options) {
  const onListening = async (app) => {
    await bootstrapWebSocketServer(app.server, options);
  };
  cds.on("listening", onListening);
}
/**
 * Creates the WebSocket server on top of the CDS HTTP server and serves all websocket endpoints.
 *
 * Two kinds of services are served:
 * - services whose protocol is websocket (see isServedViaWebsocket), with operations, entities and events;
 * - other services declaring events annotated `@websocket`/`@ws`, with only those events, through a
 *   pseudo-service whose websocket endpoints are derived from the service's own endpoints.
 *
 * Does nothing when initWebSocketServer returns no server (disabled or failed setup).
 *
 * @param {object} server - HTTP server emitted with the CDS "listening" event
 * @param {object} options - adapter options; `options.path` and `options.services` are read
 */
async function bootstrapWebSocketServer(server, options) {
  socketServer = await initWebSocketServer(server, options.path);
  if (!socketServer) {
    return;
  }
  // Services served via the websocket protocol
  for (const serviceName in options.services) {
    const service = options.services[serviceName];
    if (isServedViaWebsocket(service)) {
      serveWebSocketService(socketServer, service, options);
    }
  }
  // Websocket-annotated events of services served via other protocols
  const eventServices = collectEventServices();
  for (const name in eventServices) {
    const eventService = eventServices[name];
    if (Object.keys(eventService.events()).length > 0) {
      serveWebSocketService(socketServer, eventService, options);
    }
  }
  LOG?.info("using websocket", {
    kind: cds.env.websocket.kind,
    adapter: socketServer.adapter ? { impl: socketServer.adapterImpl, active: socketServer.adapterActive } : false,
  });
}

/**
 * Groups the model events annotated `@websocket`/`@ws` by their owning service, for services that
 * are NOT themselves served via websocket.
 *
 * @returns {Object<string, object>} event pseudo-services keyed by service name
 */
function collectEventServices() {
  const eventServices = {};
  for (const name in cds.model.definitions) {
    const definition = cds.model.definitions[name];
    if (definition.kind !== "event" || !(definition["@websocket"] || definition["@ws"])) {
      continue;
    }
    const service = cds.services[definition._service?.name];
    if (!service || isServedViaWebsocket(service)) {
      continue;
    }
    eventServices[service.name] ??= createEventService(service);
    eventServices[service.name]._events[serviceLocalName(service, definition.name)] = definition;
  }
  return eventServices;
}

/**
 * Builds a pseudo-service with these properties:
 * - no operations or entities;
 * - events are collected in `_events`;
 * - `on`/`tx` are delegated to the real service.
 *
 * @param {object} service - the real CDS service owning the events
 * @returns {object} service-like object accepted by serveWebSocketService
 */
function createEventService(service) {
  return {
    name: service.name,
    definition: service.definition,
    endpoints: (service.endpoints || []).map((endpoint) => ({
      kind: "websocket",
      path: deriveEventEndpointPath(service, endpoint),
    })),
    operations: () => interableObject(),
    entities: () => interableObject(),
    _events: interableObject(),
    events() {
      return this._events;
    },
    on: service.on.bind(service),
    tx: service.tx.bind(service),
  };
}

/**
 * Derives the websocket path for one endpoint of an event-only service.
 * The service path is made relative to the endpoint protocol's path and mounted below the
 * websocket protocol path (default "/ws"). Results that are still absolute ("/...") are kept as-is.
 *
 * @param {object} service - the real CDS service
 * @param {object} endpoint - one entry of `service.endpoints` (`kind` is read)
 * @returns {string} websocket path
 */
function deriveEventEndpointPath(service, endpoint) {
  const protocols = cds.env.protocols;
  // `odata` endpoints fall back to the `odata-v4` protocol configuration
  const protocol = protocols?.[endpoint.kind] || (endpoint.kind === "odata" ? protocols?.["odata-v4"] : null);
  // Endpoint kinds without a protocol configuration have no prefix to strip
  const path = normalizeServicePath(service.path, protocol?.path);
  if (path.startsWith("/")) {
    return path;
  }
  return `${protocols?.websocket?.path || protocols?.ws?.path || "/ws"}/${path}`;
}
/**
 * Instantiates and sets up the socket server implementation configured for the adapter.
 *
 * The effective configuration is `cds.env.requires.websocket` overlaid by `cds.env.websocket`
 * (the latter wins). It is written back to `cds.env.websocket`, with `kind` defaulting to "ws".
 * The implementation is resolved via `SocketServer.require(impl || kind, "socket")`.
 *
 * @param {object} server - HTTP server to attach to
 * @param {string} path - base path passed to the implementation
 * @returns {Promise<object|undefined>} the set-up server instance. Returns undefined when:
 *   - websockets are disabled (`cds.env.websocket === false`), or
 *   - creation/setup failed (the error is logged).
 */
async function initWebSocketServer(server, path) {
  if (cds.env.websocket === false) {
    return;
  }
  try {
    // Spreading undefined is a no-op, so no prior `??= {}` initialization is needed
    cds.env.websocket = { ...cds.env.requires?.websocket, ...cds.env.websocket };
    cds.env.websocket.kind ??= "ws";
    const serverImpl = cds.env.websocket.impl || cds.env.websocket.kind;
    const SocketModule = SocketServer.require(serverImpl, "socket");
    // Kept local: the caller assigns the module-level `socketServer` from the return value
    const instance = new SocketModule(server, path, cds.env.websocket);
    await instance.setup();
    return instance;
  } catch (err) {
    LOG?.error(err);
  }
}
/**
 * Closes the Redis clients, waiting at most TIMEOUT_SHUTDOWN milliseconds.
 *
 * @returns {Promise<*>} one of:
 *   - the result of `redis.closeClients()`;
 *   - undefined if the timeout elapsed first (logged at info level);
 *   - a rejection if closing the clients failed before the timeout.
 */
async function shutdownWebSocketServer() {
  let timeoutHandle;
  const timeoutReached = new Promise((resolve) => {
    timeoutHandle = setTimeout(() => {
      LOG?.info("Shutdown timeout reached!");
      resolve();
    }, TIMEOUT_SHUTDOWN);
  });
  try {
    return await Promise.race([redis.closeClients(), timeoutReached]);
  } finally {
    clearTimeout(timeoutHandle);
  }
}
/**
 * Strips a leading `<protocolPath>/` from `servicePath`; other paths are returned unchanged.
 *
 * @param {string} servicePath - e.g. "/ws/chat"
 * @param {string} protocolPath - e.g. "/ws"
 * @returns {string} e.g. "chat"
 */
function normalizeServicePath(servicePath, protocolPath) {
  const prefix = `${protocolPath}/`;
  return servicePath.startsWith(prefix) ? servicePath.slice(prefix.length) : servicePath;
}
/**
 * Serves `service` on each of its websocket endpoints (endpoint `kind` "websocket" or "ws").
 *
 * For each endpoint, the service events are bound once. Each connecting socket then gets:
 * - the default handlers,
 * - the operation and entity listeners,
 * - the connect event.
 * Errors are logged; a socket whose setup throws is disconnected.
 *
 * @param {object} socketServer - server created by initWebSocketServer
 * @param {object} service - CDS service or event pseudo-service
 * @param {object} options - adapter options (`options.path` is stripped from endpoint paths)
 */
function serveWebSocketService(socketServer, service, options) {
  const onConnection = (socket) => {
    try {
      bindServiceDefaults(socket, service);
      bindServiceOperations(socket, service);
      bindServiceEntities(socket, service);
      // Not awaited: the connect event runs in the background
      emitConnect(socket, service);
    } catch (err) {
      LOG?.error(err);
      socket.disconnect();
    }
  };
  const endpoints = service.endpoints || [];
  for (const endpoint of endpoints) {
    const isWebsocketEndpoint = ["websocket", "ws"].includes(endpoint.kind);
    if (!isWebsocketEndpoint) {
      continue;
    }
    const path = normalizeServicePath(endpoint.path, options.path);
    try {
      bindServiceEvents(socketServer, service, path);
      socketServer.service(service, path, onConnection);
    } catch (err) {
      LOG?.error(err);
    }
  }
}
/**
 * Registers an `on` handler for every event of `service` that broadcasts the emitted event to the
 * connected sockets. Event payload, tenant, targeting (user/context/identifier) and websocket
 * headers are derived from the event definition and the request. Broadcast errors are logged.
 *
 * @param {object} socketServer - server offering `broadcast(message)`
 * @param {object} service - CDS service or event pseudo-service
 * @param {string} path - default websocket path (overridable per event via `@websocket.path`)
 */
function bindServiceEvents(socketServer, service, path) {
  const broadcastToSockets = async (event, req) => {
    try {
      const format = deriveFormat(service, event);
      // deriveHeaders rewrites the x-websocket-*/x-ws-* headers of `req.headers` in place
      const headers = deriveHeaders(req.headers, format);
      await socketServer.broadcast({
        service,
        path: derivePath(event, path),
        event: serviceLocalName(service, event.name),
        data: req.data,
        tenant: req.tenant,
        user: deriveUser(event, req.data, headers, req),
        context: deriveContext(event, req.data, headers),
        identifier: deriveIdentifier(event, req.data, headers),
        headers: deriveEventHeaders(headers),
        socket: null,
      });
    } catch (err) {
      LOG?.error(err);
    }
  };
  for (const event of service.events()) {
    service.on(event, async (req) => {
      await broadcastToSockets(event, req);
    });
  }
}
/**
 * Wires the reserved default handlers on a socket.
 * - wsDisconnect: if the service defines it, it runs on disconnect. The reason is passed
 *   stringified, and only when the operation declares a `reason: String` parameter.
 * - wsContext: always listened to.
 *   - `reset` resets the socket first.
 *   - `context` (string or array) / `contexts` are entered, or exited when `exit` is set.
 *   - The service's wsContext operation runs afterwards if defined; otherwise the ack callback is called.
 *
 * @param {object} socket - connected socket facade
 * @param {object} service - CDS service or event pseudo-service
 */
function bindServiceDefaults(socket, service) {
  const disconnectOperation = service.operations[WebSocketAction.Disconnect];
  if (disconnectOperation) {
    socket.onDisconnect(async (reason) => {
      const data = {};
      const acceptsReason = disconnectOperation?.params?.reason?.type === "cds.String";
      if (reason !== undefined && acceptsReason) {
        data.reason = stringValue(reason);
      }
      await processEvent(socket, service, WebSocketAction.Disconnect, data);
    });
  }
  // Applies `apply` to each requested context: `contexts` (falsy entries skipped) wins over `context`
  const forEachRequestedContext = async (data, apply) => {
    if (data?.contexts) {
      for (const context of data.contexts) {
        if (context) {
          await apply(context);
        }
      }
    } else if (data?.context) {
      await apply(data.context);
    }
  };
  socket.on(WebSocketAction.Context, async (data, headers, callback) => {
    if (data?.reset) {
      await socket.reset();
    }
    // An array given as `context` is moved to `contexts` (the wsContext operation sees this too)
    if (Array.isArray(data?.context)) {
      data.contexts = data.context;
      delete data.context;
    }
    const apply = data?.exit ? (context) => socket.exit(context) : (context) => socket.enter(context);
    await forEachRequestedContext(data, apply);
    if (service.operations[WebSocketAction.Context]) {
      await processEvent(socket, service, WebSocketAction.Context, data, headers, callback);
    } else if (callback) {
      await callback();
    }
  });
}
/**
 * Registers one socket listener per unbound service operation (named by its service-local name).
 * Each listener forwards the payload to processEvent. The reserved wsConnect/wsDisconnect/wsContext
 * operations are skipped.
 *
 * @param {object} socket - connected socket facade
 * @param {object} service - CDS service or event pseudo-service
 */
function bindServiceOperations(socket, service) {
  const reservedEvents = Object.values(WebSocketAction);
  for (const operation of service.operations()) {
    const event = serviceLocalName(service, operation.name);
    if (!reservedEvents.includes(event)) {
      socket.on(event, async (data, headers, callback) => {
        await processEvent(socket, service, event, data, headers, callback);
      });
    }
  }
}
/**
 * Registers socket listeners for every entity of the service (named by the service-local entity name).
 * - `<Entity>:read`, `:readDeep`, `:list` and bound actions `<Entity>:<action>`: the result (or
 *   error) is acknowledged to the caller only.
 * - `<Entity>:create`, `:update`, `:delete`: the result is acknowledged to the caller and then broadcast
 *   as `<Entity>:created`, `:updated`, `:deleted`.
 *   - The delete broadcast payload is the response merged with the request data.
 *   - Nothing is broadcast when the operation fails.
 *   - A failing broadcast is logged without acknowledging the caller a second time.
 *
 * @param {object} socket - connected socket facade
 * @param {object} service - CDS service or event pseudo-service
 */
function bindServiceEntities(socket, service) {
  for (const entity of service.entities()) {
    const localEntityName = serviceLocalName(service, entity.name);
    const bindRequest = (event) => {
      socket.on(`${localEntityName}:${event}`, async (data, headers, callback) => {
        await processCRUD(socket, service, entity, event, data, headers, callback);
      });
    };
    const bindModification = (event, broadcastSuffix, toBroadcastData) => {
      socket.on(`${localEntityName}:${event}`, async (data, headers, callback) => {
        await processCRUDWithBroadcast(socket, service, entity, event, data, headers, callback, {
          event: `${localEntityName}:${broadcastSuffix}`,
          data: (response) => toBroadcastData(response, data),
        });
      });
    };
    bindModification("create", "created", (response) => response);
    bindRequest("read");
    bindRequest("readDeep");
    bindModification("update", "updated", (response) => response);
    bindModification("delete", "deleted", (response, data) => ({ ...response, ...data }));
    bindRequest("list");
    for (const actionName in entity.actions) {
      bindRequest(entity.actions[actionName].name);
    }
  }
}

/**
 * Runs a modifying CRUD request and separates its three outcomes:
 * - on failure, the caller gets `{ error: { code, message } }` and nothing is broadcast;
 * - on success, the caller gets the response, then the broadcast is emitted;
 * - errors from the ack callback or the broadcast are logged only.
 *
 * @param {object} broadcast - `{ event, data(response) }` describing the success broadcast
 */
async function processCRUDWithBroadcast(socket, service, entity, event, data, headers, callback, broadcast) {
  let response;
  try {
    response = await callCRUD(socket, service, entity, event, data, headers);
  } catch (err) {
    LOG?.error(err);
    try {
      if (callback) {
        await callback({
          error: {
            code: err.code || err.status || err.statusCode,
            message: err.message,
          },
        });
      }
    } catch (callbackErr) {
      LOG?.error(callbackErr);
    }
    return;
  }
  try {
    if (callback) {
      await callback(response);
    }
  } catch (err) {
    LOG?.error(err);
  }
  try {
    await broadcastEvent(socket, service, broadcast.event, entity, broadcast.data(response));
  } catch (err) {
    LOG?.error(err);
  }
}
/**
 * Runs the service's `wsConnect` operation for a freshly connected socket, if the service defines one.
 *
 * @param {object} socket - connected socket facade
 * @param {object} service - CDS service or event pseudo-service
 */
async function emitConnect(socket, service) {
  const hasConnectOperation = Boolean(service.operations[WebSocketAction.Connect]);
  if (!hasConnectOperation) {
    return;
  }
  await processEvent(socket, service, WebSocketAction.Connect);
}
/**
 * Sends `event` to the service and acknowledges the outcome through the optional socket callback.
 * - Success: the callback receives the result.
 * - Error (from the call or from the ack): the error is logged and the callback receives
 *   `{ error: { code, message } }`.
 * Errors raised while acknowledging the error are logged only.
 *
 * @param {object} socket - socket facade providing `context`
 * @param {object} service - target service
 * @param {string} event - event/operation name
 * @param {object} [data] - payload
 * @param {object} [headers] - request headers
 * @param {Function} [callback] - socket acknowledgement callback
 */
async function processEvent(socket, service, event, data, headers, callback) {
  try {
    const result = await callEvent(socket, service, event, data, headers);
    if (callback) {
      await callback(result);
    }
  } catch (err) {
    LOG?.error(err);
    if (!callback) {
      return;
    }
    try {
      await callback({
        error: {
          code: err.code || err.status || err.statusCode,
          message: err.message,
        },
      });
    } catch (callbackErr) {
      LOG?.error(callbackErr);
    }
  }
}
/**
 * Sends `event` with `data` (default `{}`) and `headers` to the service, inside a transaction
 * opened for the socket's context.
 *
 * @returns {Promise<*>} the result of `srv.send`
 */
async function callEvent(socket, service, event, data, headers) {
  const payload = data || {};
  const result = await service.tx(socket.context, async (srv) => srv.send({ event, data: payload, headers }));
  return result;
}
/**
 * Executes a CRUD request or bound action on `entity` and acknowledges it through the optional
 * socket callback.
 * - Success: the callback receives the response.
 * - Error (from the call or from the ack): the error is logged and the callback receives
 *   `{ error: { code, message } }`.
 * Errors raised while acknowledging the error are logged only.
 *
 * @param {object} socket - socket facade providing `context`
 * @param {object} service - target service
 * @param {object} entity - entity definition
 * @param {string} event - "create" | "read" | "readDeep" | "update" | "delete" | "list" | action name
 * @param {object} [data] - payload
 * @param {object} [headers] - request headers
 * @param {Function} [callback] - socket acknowledgement callback
 */
async function processCRUD(socket, service, entity, event, data, headers, callback) {
  try {
    const result = await callCRUD(socket, service, entity, event, data, headers);
    if (callback) {
      await callback(result);
    }
  } catch (err) {
    LOG?.error(err);
    if (!callback) {
      return;
    }
    try {
      await callback({
        error: {
          code: err.code || err.status || err.statusCode,
          message: err.message,
        },
      });
    } catch (callbackErr) {
      LOG?.error(callbackErr);
    }
  }
}
/**
 * Executes a CRUD request or bound action on `entity`, inside a transaction opened for the socket's context.
 * - create/read/readDeep/update/delete/list are translated into CQN queries keyed by the entity's
 *   key values taken from `data` (`list` filters by the whole `data` object).
 * - Any other event is sent as a bound action/function call on the entity.
 *
 * @returns {Promise<*>} the result of `srv.send`
 */
async function callCRUD(socket, service, entity, event, data, headers) {
  const payload = data || {};
  return await service.tx(socket.context, async (srv) => {
    const key = deriveKey(entity, payload);
    // Map (not a plain object) so action names like "toString" never hit inherited properties
    const queryBuilders = new Map([
      ["create", () => srv.create(entity).entries(payload)],
      ["read", () => SELECT.one.from(entity).where(key)],
      ["readDeep", () => SELECT.one.from(entity).columns(getDeepEntityColumns(entity)).where(key)],
      ["update", () => srv.update(entity).set(payload).where(key)],
      ["delete", () => srv.delete(entity).where(key)],
      ["list", () => srv.read(entity).where(payload)],
    ]);
    const buildQuery = queryBuilders.get(event);
    if (buildQuery) {
      return await srv.send({ query: buildQuery(), headers });
    }
    return await srv.send({
      event,
      entity: entity.name,
      data: payload,
      headers,
    });
  });
}
/**
 * Broadcasts an entity change event (e.g. "Books:created") from a socket.
 *
 * If the service declares a matching event (either under the same name, or with ":" replaced by
 * "."), its annotations determine targeting (user/context/identifier) and payload elements.
 * Otherwise the entity's broadcast annotations determine the payload.
 * Entities annotated `@websocket.broadcast.all`/`@ws.broadcast.all` use `socket.broadcastAll`.
 * Nothing is sent when the payload is empty.
 */
async function broadcastEvent(socket, service, event, entity, data, headers) {
  const events = service.events();
  const eventDefinition = events[event] || events[event.replaceAll(/:/g, ".")];
  const { user, context, identifier } = eventDefinition
    ? {
        user: deriveUser(eventDefinition, data, headers, socket),
        context: deriveContext(eventDefinition, data, headers),
        identifier: deriveIdentifier(eventDefinition, data, headers),
      }
    : {};
  const contentData = broadcastData(entity, data, headers, eventDefinition);
  if (!contentData) {
    return;
  }
  const toAll = entity["@websocket.broadcast.all"] || entity["@ws.broadcast.all"];
  if (toAll) {
    await socket.broadcastAll(event, contentData, user, context, identifier, headers);
    return;
  }
  await socket.broadcast(event, contentData, user, context, identifier, headers);
}
/**
 * Selects the payload of an entity change broadcast.
 * - With an event definition: the event's elements taken from `data`.
 * - Otherwise, according to the entity annotation `@websocket.broadcast(.content)` /
 *   `@ws.broadcast(.content)`:
 *   - "data": the full data;
 *   - "none": nothing (undefined);
 *   - "key" or anything else: the entity key values.
 */
function broadcastData(entity, data, headers, event) {
  if (event) {
    return deriveElements(event, data, headers);
  }
  const content =
    entity["@websocket.broadcast.content"] ||
    entity["@ws.broadcast.content"] ||
    entity["@websocket.broadcast"] ||
    entity["@ws.broadcast"];
  if (content === "data") {
    return data;
  }
  if (content === "none") {
    return undefined;
  }
  return deriveKey(entity, data, headers);
}
/**
 * Resolves the wire format of an event. The first truthy value of the following wins:
 * 1. event annotation `@websocket.format` / `@ws.format`,
 * 2. service annotation `@websocket.format` / `@ws.format`,
 * 3. the default "json".
 */
function deriveFormat(service, event) {
  const eventFormat = event["@websocket.format"] || event["@ws.format"];
  if (eventFormat) {
    return eventFormat;
  }
  const definition = service.definition;
  return definition["@websocket.format"] || definition["@ws.format"] || "json";
}
/**
 * Picks the entity's key values from `data`. Every key is present in the result, possibly
 * `undefined`. `headers` is unused.
 */
function deriveKey(entity, data, headers) {
  return Object.fromEntries(Object.keys(entity.keys).map((key) => [key, data[key]]));
}
/**
 * Picks the values of all event elements from `data`. Every element is present in the result,
 * possibly `undefined`. `headers` is unused.
 */
function deriveElements(event, data, headers) {
  return Object.fromEntries(Object.keys(event.elements).map((element) => [element, data[element]]));
}
/**
 * Combines current-user targeting (deriveCurrentUser) with explicitly defined users
 * (deriveDefinedUser) into one `{ include, exclude }` entry, or undefined.
 */
function deriveUser(event, data, headers, req) {
  return combineEntries(
    deriveCurrentUser(event, data, headers, req),
    deriveDefinedUser(event, data, headers, req),
  );
}
/**
 * Derives include/exclude targeting for the current user (`req.context.user.id`).
 * - Included when any of these is requested: header/annotation `currentUser`,
 *   `@websocket.user: "includeCurrent"`, or `currentUser.include`.
 * - Excluded for `@websocket.user: "excludeCurrent"` or `currentUser.exclude`.
 * (Each is also accepted in its `ws`/`broadcast` variants.)
 *
 * @returns {{include: (string[]|undefined), exclude: (string[]|undefined)}|undefined}
 *   undefined when both lists are undefined
 */
function deriveCurrentUser(event, data, headers, req) {
  const currentUserId = req.context.user?.id;
  const userAnnotations = ["@websocket.user", "@ws.user", "@websocket.broadcast.user", "@ws.broadcast.user"];
  const requested = deriveValues(event, data, headers, {
    headerNames: ["wsCurrentUser", "currentUser"],
    annotationNames: [
      "@websocket.currentUser",
      "@ws.currentUser",
      "@websocket.broadcast.currentUser",
      "@ws.broadcast.currentUser",
    ],
    resultValue: currentUserId,
  });
  const includedViaUser = deriveValues(event, data, headers, {
    annotationNames: userAnnotations,
    annotationCompareValue: "includeCurrent",
    resultValue: currentUserId,
  });
  const includedViaFlag = deriveValues(event, data, headers, {
    headerNames: ["wsCurrentUser.include", "wsCurrentUserInclude", "currentUser.include", "currentUserInclude"],
    annotationNames: [
      "@websocket.currentUser.include",
      "@ws.currentUser.include",
      "@websocket.broadcast.currentUser.include",
      "@ws.broadcast.currentUser.include",
    ],
    annotationCompareValue: true,
    resultValue: currentUserId,
  });
  const excludedViaUser = deriveValues(event, data, headers, {
    annotationNames: userAnnotations,
    annotationCompareValue: "excludeCurrent",
    resultValue: currentUserId,
  });
  const excludedViaFlag = deriveValues(event, data, headers, {
    headerNames: ["wsCurrentUser.exclude", "wsCurrentUserExclude", "currentUser.exclude", "currentUserExclude"],
    annotationNames: [
      "@websocket.currentUser.exclude",
      "@ws.currentUser.exclude",
      "@websocket.broadcast.currentUser.exclude",
      "@ws.broadcast.currentUser.exclude",
    ],
    annotationCompareValue: true,
    resultValue: currentUserId,
  });
  const include = combineValues(requested, includedViaUser, includedViaFlag);
  const exclude = combineValues(excludedViaUser, excludedViaFlag);
  return include || exclude ? { include, exclude } : undefined;
}
/**
 * Derives include/exclude targeting for explicitly named users, taken from headers
 * (`wsUsers`/`user`/`user.include`/`user.exclude` …) and from annotations on the event or its elements.
 * The `@websocket.user` values "includeCurrent"/"excludeCurrent" are ignored here, because they
 * are handled by deriveCurrentUser.
 *
 * @returns {{include: (string[]|undefined), exclude: (string[]|undefined)}|undefined}
 *   undefined when both lists are undefined
 */
function deriveDefinedUser(event, data, headers) {
  const listed = deriveValues(event, data, headers, {
    headerNames: ["wsUsers", "wsUser", "users", "user"],
    annotationNames: ["@websocket.user", "@ws.user", "@websocket.broadcast.user", "@ws.broadcast.user"],
    annotationExcludeValues: ["includeCurrent", "excludeCurrent"],
  });
  const included = deriveValues(event, data, headers, {
    headerNames: ["wsUser.include", "wsUserInclude", "user.include", "userInclude"],
    annotationNames: [
      "@websocket.user.include",
      "@ws.user.include",
      "@websocket.broadcast.user.include",
      "@ws.broadcast.user.include",
    ],
  });
  const exclude = deriveValues(event, data, headers, {
    headerNames: ["wsUser.exclude", "wsUserExclude", "user.exclude", "userExclude"],
    annotationNames: [
      "@websocket.user.exclude",
      "@ws.user.exclude",
      "@websocket.broadcast.user.exclude",
      "@ws.broadcast.user.exclude",
    ],
  });
  const include = combineValues(listed, included);
  return include || exclude ? { include, exclude } : undefined;
}
/**
 * Derives include/exclude context targeting, taken from headers
 * (`wsContexts`/`context`/`context.include`/`context.exclude` …) and from annotations on the event or its elements.
 *
 * @returns {{include: (string[]|undefined), exclude: (string[]|undefined)}|undefined}
 *   undefined when both lists are undefined
 */
function deriveContext(event, data, headers) {
  const listed = deriveValues(event, data, headers, {
    headerNames: ["wsContexts", "wsContext", "contexts", "context"],
    annotationNames: ["@websocket.context", "@ws.context", "@websocket.broadcast.context", "@ws.broadcast.context"],
  });
  const included = deriveValues(event, data, headers, {
    headerNames: ["wsContext.include", "wsContextInclude", "context.include", "contextInclude"],
    annotationNames: [
      "@websocket.context.include",
      "@ws.context.include",
      "@websocket.broadcast.context.include",
      "@ws.broadcast.context.include",
    ],
  });
  const exclude = deriveValues(event, data, headers, {
    headerNames: ["wsContext.exclude", "wsContextExclude", "context.exclude", "contextExclude"],
    annotationNames: [
      "@websocket.context.exclude",
      "@ws.context.exclude",
      "@websocket.broadcast.context.exclude",
      "@ws.broadcast.context.exclude",
    ],
  });
  const include = combineValues(listed, included);
  return include || exclude ? { include, exclude } : undefined;
}
/**
 * Derives include/exclude socket-identifier targeting, taken from headers
 * (`wsIdentifiers`/`identifier`/`identifier.include`/`identifier.exclude` …) and from annotations on
 * the event or its elements.
 *
 * @returns {{include: (string[]|undefined), exclude: (string[]|undefined)}|undefined}
 *   undefined when both lists are undefined
 */
function deriveIdentifier(event, data, headers) {
  const listed = deriveValues(event, data, headers, {
    headerNames: ["wsIdentifiers", "wsIdentifier", "identifiers", "identifier"],
    annotationNames: [
      "@websocket.identifier",
      "@ws.identifier",
      "@websocket.broadcast.identifier",
      "@ws.broadcast.identifier",
    ],
  });
  const included = deriveValues(event, data, headers, {
    headerNames: ["wsIdentifier.include", "wsIdentifierInclude", "identifier.include", "identifierInclude"],
    annotationNames: [
      "@websocket.identifier.include",
      "@ws.identifier.include",
      "@websocket.broadcast.identifier.include",
      "@ws.broadcast.identifier.include",
    ],
  });
  const exclude = deriveValues(event, data, headers, {
    headerNames: ["wsIdentifier.exclude", "wsIdentifierExclude", "identifier.exclude", "identifierExclude"],
    annotationNames: [
      "@websocket.identifier.exclude",
      "@ws.identifier.exclude",
      "@websocket.broadcast.identifier.exclude",
      "@ws.broadcast.identifier.exclude",
    ],
  });
  const include = combineValues(listed, included);
  return include || exclude ? { include, exclude } : undefined;
}
/**
 * Collects targeting values (users, contexts, identifiers) from three sources, in this order:
 * 1. annotations on the event definition itself,
 * 2. annotations on the event's elements; a matching element's value is read from `data[name]`,
 * 3. headers, looked up by (possibly dotted) name via accessPath.
 *
 * Matching rules:
 * - `annotationCompareValue` / `headerCompareValue` undefined: any truthy value matches;
 *   otherwise the value must be strictly equal (===) to the compare value.
 * - `annotationExcludeValues`: annotation values contained in this list are skipped.
 * - `resultValue`: when given, it is collected instead of the matched value. For elements it is
 *   collected only if the element's value in `data` is truthy.
 * - Elements annotated `@websocket.ignore` / `@ws.ignore` are skipped.
 *
 * Annotation sources (1, 2) are only consulted when `data` is truthy; headers are consulted regardless.
 * No caller in this file passes `headerCompareValue`.
 *
 * @param {object} event - event definition (`@…` annotation properties, optional `elements`)
 * @param {object} data - event payload
 * @param {object} headers - derived headers
 * @param {object} options - lookup configuration (see above)
 * @returns {string[]|undefined} de-duplicated values (stringified via mergeValue), or undefined
 *   when nothing was collected
 */
function deriveValues(
  event,
  data,
  headers,
  { headerNames, annotationNames, headerCompareValue, annotationCompareValue, resultValue, annotationExcludeValues },
) {
  let result = undefined;
  if (data) {
    for (const annotationName of annotationNames || []) {
      // 1. Event-level annotation
      const annotationValue = event[annotationName];
      if (annotationExcludeValues?.includes(annotationValue)) {
        continue;
      }
      if (annotationCompareValue === undefined ? annotationValue : annotationValue === annotationCompareValue) {
        result = mergeValue(result, resultValue ?? annotationValue);
      }
      // 2. Element-level annotation: the annotated element's payload value is collected.
      // Note: the elements are scanned once per annotation name (nested loop).
      if (event.elements) {
        for (const name in event.elements) {
          const element = event.elements[name];
          if (element["@websocket.ignore"] || element["@ws.ignore"]) {
            continue;
          }
          // Shadows the event-level `annotationValue` above
          const annotationValue = element[annotationName];
          if (annotationExcludeValues?.includes(annotationValue)) {
            continue;
          }
          if (annotationCompareValue === undefined ? annotationValue : annotationValue === annotationCompareValue) {
            if (resultValue === undefined) {
              result = mergeValue(result, data[name]);
            } else if (data[name]) {
              // Element acts as a flag: collect `resultValue` only when the payload field is truthy
              result = mergeValue(result, resultValue);
            }
          }
        }
      }
    }
  }
  // 3. Headers
  if (headers) {
    for (const headerName of headerNames || []) {
      const headerValue = accessPath(headers, headerName);
      // Plain-object header values are skipped (presumably nested groups such as `{ include, exclude }`)
      if (headerValue?.constructor === Object) {
        continue;
      }
      if (headerCompareValue === undefined ? headerValue : headerValue === headerCompareValue) {
        result = mergeValue(result, resultValue ?? headerValue);
      }
    }
  }
  return removeArrayDuplicates(result);
}
/**
 * Merges two optional `{ include, exclude }` entries list-wise via combineValues.
 *
 * @returns {{include: *, exclude: *}|undefined} undefined when both merged lists are undefined
 */
function combineEntries(entryA, entryB) {
  const include = combineValues(entryA?.include, entryB?.include);
  const exclude = combineValues(entryA?.exclude, entryB?.exclude);
  return include || exclude ? { include, exclude } : undefined;
}
/**
 * Concatenates all defined arguments (arrays are flattened one level, like Array#concat) and
 * removes duplicates.
 *
 * @param {...*} values - values or arrays of values; `undefined` entries are ignored
 * @returns {Array|undefined} undefined when every argument is undefined
 */
function combineValues(...values) {
  const present = values.filter((value) => value !== undefined);
  if (present.length === 0) {
    return undefined;
  }
  return removeArrayDuplicates([].concat(...present));
}
/**
 * Reads a dot-separated property path (e.g. "wsContext.include") from `object`.
 *
 * @returns {*} the value at the path. Returns null if a falsy value is reached before the
 *   last segment, or if `object` itself is falsy.
 */
function accessPath(object, path) {
  let current = object;
  for (const segment of path.split(".")) {
    if (!current) {
      return null;
    }
    current = current[segment];
  }
  return current;
}
/**
 * Appends `value` (each entry, if an array) as strings to `result`, creating the array on
 * demand. `result` is mutated in place and returned.
 * - `null`/`undefined` values leave `result` untouched.
 * - Non-array objects are ignored, but the array is still created.
 */
function mergeValue(result, value) {
  if (value === undefined || value === null) {
    return result;
  }
  const target = result ?? [];
  if (Array.isArray(value)) {
    value.forEach((entry) => target.push(stringValue(entry)));
  } else if (!(value instanceof Object)) {
    target.push(stringValue(value));
  }
  return target;
}
/**
 * Returns a new array without duplicates (first occurrence wins, SameValueZero equality).
 * Non-array input is returned unchanged.
 */
function removeArrayDuplicates(array) {
  return Array.isArray(array) ? Array.from(new Set(array)) : array;
}
/**
 * Converts a value to a string:
 * - Dates become ISO strings,
 * - other objects (including arrays) become JSON,
 * - everything else goes through String().
 */
function stringValue(value) {
  if (value instanceof Date) {
    return value.toISOString();
  }
  return value instanceof Object ? JSON.stringify(value) : String(value);
}
/**
 * Interprets a header string value:
 * - "true"/"false" become booleans,
 * - numeric strings become numbers,
 * - anything else (including non-strings) is returned unchanged.
 *
 * Blank strings stay strings: `Number("")` is 0, so the previous `isNaN` check let them through
 * and `parseFloat` then produced NaN. The conversion now uses the same `Number()` semantics as
 * the check, so e.g. "0x10" yields 16 instead of `parseFloat`'s 0.
 *
 * @param {*} value - raw header value
 * @returns {*} boolean, number, or the original value
 */
function parseStringValue(value) {
  if (typeof value !== "string") {
    return value;
  }
  if (value === "true" || value === "false") {
    return value === "true";
  }
  if (value.trim() !== "") {
    const number = Number(value);
    if (!Number.isNaN(number)) {
      return number;
    }
  }
  return value;
}
/**
 * Normalizes websocket-specific request headers in place and returns the same object.
 *
 * Each header starting (case-insensitively) with "x-websocket-" or "x-ws-":
 * - is removed;
 * - has its value parsed via parseStringValue;
 * - is re-added under its lowercased suffix and its camelCase form (e.g. "x-ws-context-include"
 *   becomes "context-include" and "contextInclude").
 * If the suffix starts with `<format>-`, the entry is stored under `headers.ws[format]` instead.
 *
 * Header names are lowercased with `toLowerCase` (locale-independent; `toLocaleLowerCase` can
 * map "I" differently in e.g. Turkish locales). The keys are snapshotted up front, because the
 * loop deletes and adds properties on `headers`.
 *
 * @param {object|undefined} headers - request headers (mutated)
 * @param {string} format - event format (see deriveFormat)
 * @returns {object|undefined} `headers`
 */
function deriveHeaders(headers, format) {
  const formatPrefix = `${format}-`;
  for (const header of Object.keys(headers ?? {})) {
    const lowerHeader = header.toLowerCase();
    let xHeader;
    if (lowerHeader.startsWith("x-websocket-")) {
      xHeader = lowerHeader.substring("x-websocket-".length);
    } else if (lowerHeader.startsWith("x-ws-")) {
      xHeader = lowerHeader.substring("x-ws-".length);
    } else {
      continue;
    }
    const formatSpecific = xHeader.startsWith(formatPrefix);
    if (formatSpecific) {
      xHeader = xHeader.substring(formatPrefix.length);
    }
    const value = parseStringValue(headers[header]);
    delete headers[header];
    let target = headers;
    if (formatSpecific) {
      headers.ws ??= {};
      headers.ws[format] ??= {};
      target = headers.ws[format];
    }
    target[xHeader] = value;
    target[toCamelCase(xHeader)] = value;
  }
  return headers;
}
/**
 * Merges the `websocket` and `ws` header groups (`ws` wins on conflicts).
 *
 * @returns {object|undefined} undefined when neither group is present
 */
function deriveEventHeaders(headers) {
  const websocketHeaders = headers?.websocket;
  const wsHeaders = headers?.ws;
  if (!websocketHeaders && !wsHeaders) {
    return undefined;
  }
  return { ...websocketHeaders, ...wsHeaders };
}
/** Returns the event's `@websocket.path` / `@ws.path` annotation, falling back to `path`. */
function derivePath(event, path) {
  const annotatedPath = event["@websocket.path"] || event["@ws.path"];
  return annotatedPath || path;
}
/**
 * Builds the CQN column list for a deep read.
 * - Every element of `entity` becomes a `{ ref }` column.
 * - Compositions are expanded recursively into their target's columns (`{ ref, expand }`).
 * - A composition whose target is already being expanded on the current path is omitted
 *   instead of recursing forever, e.g. a self-referencing `children: Composition of many Node`.
 * - The same target may still be expanded in separate branches.
 *
 * @param {object} entity - linked entity definition (`elements`; compositions provide `_target`)
 * @param {Set<object>} [path] - entities on the current expansion path (internal; omit when calling)
 * @returns {object[]} CQN columns
 */
function getDeepEntityColumns(entity, path = new Set()) {
  path.add(entity);
  try {
    const columns = [];
    for (const element of Object.values(entity.elements)) {
      if (element.type === "cds.Composition" && element.target) {
        if (path.has(element._target)) {
          continue;
        }
        columns.push({
          ref: [element.name],
          expand: getDeepEntityColumns(element._target, path),
        });
      } else {
        columns.push({
          ref: [element.name],
        });
      }
    }
    return columns;
  } finally {
    path.delete(entity);
  }
}
/** Strips the `<service.name>.` prefix from a fully qualified name; other names are returned unchanged. */
function serviceLocalName(service, name) {
  const prefix = `${service.name}.`;
  return name.startsWith(prefix) ? name.slice(prefix.length) : name;
}
/**
 * Converts "-x" / "_x" sequences (x a lowercase ASCII letter) to "X",
 * e.g. "context-include" becomes "contextInclude".
 */
function toCamelCase(string) {
  return string.replace(/[-_]([a-z])/g, (_match, letter) => letter.toUpperCase());
}
/**
 * Returns a shallow copy of `object` that is also iterable over its enumerable string-keyed
 * property values. The iterator reads the live object, so properties added later (such as the
 * event map filled in bootstrapWebSocketServer) are included.
 *
 * @param {object} [object] - initial properties
 * @returns {object} iterable copy
 */
function interableObject(object) {
  const iterable = { ...object };
  iterable[Symbol.iterator] = function* () {
    for (const key in this) {
      yield this[key];
    }
  };
  return iterable;
}
/**
 * Tells whether `service` is served via the websocket protocol. This is the case when its definition has:
 * - `@protocol` naming "websocket" or "ws" (as a string, a `{ kind }` object, or an array of these), or
 * - a protocol shortcut annotation `@websocket` / `@ws` (for protocols configured in
 *   `cds.env.protocols`), also when combined with other shortcuts such as `@odata`.
 *
 * @param {object} [service] - CDS service
 * @returns {boolean}
 */
function isServedViaWebsocket(service) {
  const serviceDefinition = service?.definition;
  if (!serviceDefinition) {
    return false;
  }
  const isWebsocketKind = (kind) => ["websocket", "ws"].includes(kind);
  const protocols = serviceDefinition["@protocol"];
  if (protocols) {
    const protocolList = Array.isArray(protocols) ? protocols : [protocols];
    // Entries are protocol names or `{ kind, ... }` objects; null entries must not throw
    return protocolList.some((protocol) => isWebsocketKind(typeof protocol === "string" ? protocol : protocol?.kind));
  }
  // Check every shortcut, not only the first one present: `@odata @websocket` must count as websocket
  return Object.keys(cds.env.protocols || {}).some(
    (protocol) => isWebsocketKind(protocol) && Boolean(serviceDefinition[`@${protocol}`]),
  );
}
// Module entry point: collectServicesAndMountAdapter(srv, options). Presumably invoked once per
// websocket-served service by the plugin registration (NOTE(review): confirm against the caller).
module.exports = collectServicesAndMountAdapter;