UNPKG

@push.rocks/smartproxy

Version:

A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.

505 lines 48.9 kB
import * as plugins from '../../plugins.js';
import '../../core/models/socket-augmentation.js';
import { createLogger } from './models/types.js';
import { ConnectionPool } from './connection-pool.js';
import { HttpRouter } from '../../routing/router/index.js';
import { toBaseContext } from '../../core/models/route-context.js';
import { ContextCreator } from './context-creator.js';
import { SecurityManager } from './security-manager.js';
import { TemplateUtils } from '../../core/utils/template-utils.js';
import { getMessageSize, toBuffer } from '../../core/utils/websocket-utils.js';

/**
 * Client request headers that are never copied onto the upstream handshake.
 * The `ws` client writes its own Connection/Upgrade/Sec-WebSocket-Key/
 * Sec-WebSocket-Version headers, plus Sec-WebSocket-Extensions (when
 * per-message deflate is enabled) and Sec-WebSocket-Protocol (when
 * `protocols` is passed). Forwarding the client's copies as well would send
 * duplicate, possibly conflicting handshake headers to the backend.
 */
const NON_FORWARDED_HEADERS = new Set([
    'connection',
    'upgrade',
    'sec-websocket-key',
    'sec-websocket-version',
    'sec-websocket-extensions',
    'sec-websocket-protocol',
]);

/**
 * Upper bound, in bytes, on client messages buffered while the upstream
 * connection is still being established. Exceeding it closes the client with
 * 1013 (Try Again Later) rather than buffering without limit.
 */
const MAX_PENDING_BYTES = 1024 * 1024;

/**
 * Escape every RegExp metacharacter in `text` so it matches literally.
 * @param {string} text
 * @returns {string}
 */
function escapeRegExp(text) {
    return text.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
}

/**
 * Compile a glob where `*` matches any run of characters and every other
 * character is literal (so `.` in `*.example.com` only matches a dot).
 * @param {string} pattern
 * @returns {RegExp} Anchored regular expression for the whole string.
 */
function wildcardToRegExp(pattern) {
    return new RegExp(`^${pattern.split('*').map(escapeRegExp).join('.*')}$`);
}

/**
 * Map a close code received from one peer to a code that may be sent to the other.
 * Codes such as 1005 (no status) and 1006 (abnormal closure) are only ever
 * reported locally and must not appear in a close frame (RFC 6455 §7.4). `ws`
 * throws a TypeError if asked to send them, so they fall back to 1000.
 * @param {unknown} code
 * @returns {number}
 */
function toSendableCloseCode(code) {
    const isSendable = typeof code === 'number' &&
        ((code >= 1000 && code <= 1014 && code !== 1004 && code !== 1005 && code !== 1006) ||
            (code >= 3000 && code <= 4999));
    return isSendable ? code : 1000;
}

/**
 * Handles WebSocket connections and proxying.
 *
 * Incoming client sockets accepted by the attached WebSocket server are routed
 * through `HttpRouter`, checked against the route's security/origin settings, and
 * bridged to a freshly opened upstream WebSocket.
 */
export class WebSocketHandler {
    /**
     * @param {object} options - Handler options; `options.logLevel` picks the logger level (default 'info').
     * @param {ConnectionPool} connectionPool - Stored on the instance; not used by this class directly.
     * @param {Array} [routes=[]] - Initial route configurations.
     */
    constructor(options, connectionPool, routes = []) {
        this.options = options;
        this.connectionPool = connectionPool;
        this.routes = routes;
        this.heartbeatInterval = null;
        this.wsServer = null;
        this.contextCreator = new ContextCreator();
        this.router = null;
        this.logger = createLogger(options.logLevel || 'info');
        this.securityManager = new SecurityManager(this.logger, routes);
        // Initialize router if we have routes
        if (routes.length > 0) {
            this.router = new HttpRouter(routes, this.logger);
        }
    }

    /**
     * Set the route configurations, creating the router on first use and
     * propagating the routes to the security manager.
     * @param {Array} routes
     */
    setRoutes(routes) {
        this.routes = routes;
        if (!this.router) {
            this.router = new HttpRouter(routes, this.logger);
        } else {
            this.router.setRoutes(routes);
        }
        this.securityManager.setRoutes(routes);
    }

    /**
     * Select the appropriate target from the targets array based on sub-matching criteria.
     *
     * Targets are tried in descending `priority` order (missing priority counts as 0).
     * A target without `match` criteria matches unconditionally; otherwise every
     * configured criterion must match (see `targetCriteriaMatch`).
     *
     * @param {Array} targets - Candidate targets from `route.action.targets`.
     * @param {{port?: number, path?: string, method?: string, headers?: object}} context
     * @returns {object|null} The first matching target, or null when none matches.
     */
    selectTarget(targets, context) {
        const sortedTargets = [...targets].sort((a, b) => (b.priority || 0) - (a.priority || 0));
        for (const target of sortedTargets) {
            if (!target.match || this.targetCriteriaMatch(target.match, context)) {
                return target;
            }
        }
        // Every target without `match` criteria was already returned by the loop,
        // so reaching this point means nothing matched.
        return null;
    }

    /**
     * Check one target's `match` criteria against the request context.
     * The path criterion is only evaluated when the context has a path, and
     * method/header criteria only when the context has a method/headers.
     * @param {object} match - The target's `match` object.
     * @param {object} context - Same shape as for `selectTarget`.
     * @returns {boolean}
     */
    targetCriteriaMatch(match, context) {
        if (match.ports && !match.ports.includes(context.port)) {
            return false;
        }
        // Path patterns are globs: `*` is the only wildcard, everything else is literal.
        if (match.path && context.path && !wildcardToRegExp(match.path).test(context.path)) {
            return false;
        }
        if (match.method && context.method && !match.method.includes(context.method)) {
            return false;
        }
        if (match.headers && context.headers) {
            for (const [key, pattern] of Object.entries(match.headers)) {
                const headerValue = context.headers[key.toLowerCase()];
                if (!headerValue) {
                    return false;
                }
                if (pattern instanceof RegExp) {
                    // Global/sticky regexes keep state in lastIndex; reset so every call agrees.
                    pattern.lastIndex = 0;
                    if (!pattern.test(headerValue)) {
                        return false;
                    }
                }
                else if (headerValue !== pattern) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Initialize WebSocket server on an existing HTTPS server and start the heartbeat.
     * @param {object} server - HTTP(S) server whose upgrade requests the WebSocket server handles.
     */
    initialize(server) {
        this.wsServer = new plugins.ws.WebSocketServer({
            server: server,
            clientTracking: true
        });
        this.wsServer.on('connection', (wsIncoming, req) => {
            this.handleWebSocketConnection(wsIncoming, req);
        });
        this.startHeartbeat();
        this.logger.info('WebSocket handler initialized');
    }

    /**
     * Start the heartbeat interval (every 30 s). A client whose `isAlive` flag was
     * not reset by a pong since the previous tick is terminated; all others are pinged.
     */
    startHeartbeat() {
        if (this.heartbeatInterval) {
            clearInterval(this.heartbeatInterval);
        }
        this.heartbeatInterval = setInterval(() => {
            if (!this.wsServer || this.wsServer.clients.size === 0) {
                return; // Skip if no active connections
            }
            this.logger.debug(`WebSocket heartbeat check for ${this.wsServer.clients.size} clients`);
            this.wsServer.clients.forEach((ws) => {
                const wsWithHeartbeat = ws;
                if (wsWithHeartbeat.isAlive === false) {
                    this.logger.debug('Terminating inactive WebSocket connection');
                    return wsWithHeartbeat.terminate();
                }
                wsWithHeartbeat.isAlive = false;
                wsWithHeartbeat.ping();
            });
        }, 30000);
        // Make sure the interval doesn't keep the process alive
        if (this.heartbeatInterval.unref) {
            this.heartbeatInterval.unref();
        }
    }

    /**
     * Handle a new WebSocket connection: route it, enforce security/origin rules,
     * resolve the backend, and bridge the two sockets.
     * Any rejection closes the client with a status code and reason.
     * @param {object} wsIncoming - Accepted client WebSocket.
     * @param {object} req - The HTTP upgrade request.
     */
    handleWebSocketConnection(wsIncoming, req) {
        this.logger.debug(`WebSocket connection initiated from ${req.headers.host}`);
        try {
            // An 'error' emitted on an EventEmitter with no listener is thrown, so
            // without this a client-side protocol error (e.g. a malformed frame)
            // would surface as an uncaught exception.
            wsIncoming.on('error', (err) => {
                this.logger.error(`WebSocket client connection error: ${err.message}`);
            });

            // Heartbeat tracking, consumed by startHeartbeat()
            wsIncoming.isAlive = true;
            wsIncoming.lastPong = Date.now();
            wsIncoming.on('pong', () => {
                wsIncoming.isAlive = true;
                wsIncoming.lastPong = Date.now();
            });

            const connectionId = `ws-${Date.now()}-${Math.floor(Math.random() * 10000)}`;
            const routeContext = this.contextCreator.createHttpRouteContext(req, {
                connectionId,
                clientIp: req.socket.remoteAddress?.replace('::ffff:', '') || '0.0.0.0',
                serverIp: req.socket.localAddress?.replace('::ffff:', '') || '0.0.0.0',
                tlsVersion: req.socket.getTLSVersion?.() || undefined
            });

            const route = this.router ? this.router.routeReq(req) : undefined;
            const isForwardRoute = route && route.action.type === 'forward' &&
                route.action.targets && route.action.targets.length > 0;
            if (!isForwardRoute) {
                this.logger.warn(`No route configuration for WebSocket host: ${req.headers.host}`);
                wsIncoming.close(1008, 'No route configuration for this host');
                return;
            }
            const routeName = route.name || 'unnamed';
            const wsConfig = route.action.websocket;
            this.logger.debug(`Found matching WebSocket route: ${routeName}`);

            const selectedTarget = this.selectTarget(route.action.targets, {
                port: routeContext.port,
                path: routeContext.path,
                headers: routeContext.headers,
                method: routeContext.method
            });
            if (!selectedTarget) {
                this.logger.error(`No matching target found for route ${routeName}`);
                wsIncoming.close(1003, 'No matching target');
                return;
            }
            if (wsConfig?.enabled === false) {
                this.logger.debug(`WebSockets are disabled for route: ${routeName}`);
                wsIncoming.close(1003, 'WebSockets not supported for this route');
                return;
            }

            // Security and origin checks; both can be disabled via authenticateRequest: false.
            // The origin allow-list is enforced even when the route has no `security` block.
            if (wsConfig?.authenticateRequest !== false) {
                if (route.security && !this.securityManager.isAllowed(route, toBaseContext(routeContext))) {
                    this.logger.warn(`WebSocket connection denied by security policy for ${routeContext.clientIp}`);
                    wsIncoming.close(1008, 'Access denied by security policy');
                    return;
                }
                const origin = req.headers.origin;
                if (origin && !this.isOriginAllowed(origin, wsConfig?.allowedOrigins, routeContext)) {
                    this.logger.warn(`WebSocket origin ${origin} not allowed for route: ${routeName}`);
                    wsIncoming.close(1008, 'Origin not allowed');
                    return;
                }
            }

            const destination = this.resolveDestination(selectedTarget, routeContext);
            if (!destination) {
                wsIncoming.close(1011, 'Internal server error');
                return;
            }

            // NOTE(review): TLS to the backend is inferred from port 443 alone.
            const protocol = destination.port === 443 ? 'wss' : 'ws';
            let targetPath = req.url || '/';
            if (wsConfig?.rewritePath) {
                const originalPath = targetPath;
                targetPath = TemplateUtils.resolveTemplateVariables(wsConfig.rewritePath, {
                    ...routeContext,
                    path: targetPath
                });
                this.logger.debug(`WebSocket path rewritten: ${originalPath} -> ${targetPath}`);
            }
            const targetUrl = `${protocol}://${destination.host}:${destination.port}${targetPath}`;
            this.logger.debug(`WebSocket connection from ${req.socket.remoteAddress} to ${targetUrl}`);

            const wsOptions = {
                headers: this.buildUpstreamHeaders(req, destination, wsConfig, routeContext),
                followRedirects: true
            };
            if (wsConfig?.subprotocols && wsConfig.subprotocols.length > 0) {
                wsOptions.protocols = wsConfig.subprotocols;
            }
            else if (req.headers['sec-websocket-protocol']) {
                // Pass through client requested protocols
                wsOptions.protocols = req.headers['sec-websocket-protocol'].split(',').map(p => p.trim());
            }

            this.logger.debug(`Creating WebSocket connection to ${targetUrl} with options:`, {
                headers: wsOptions.headers,
                protocols: wsOptions.protocols
            });
            const wsOutgoing = new plugins.wsDefault(targetUrl, wsOptions);
            this.logger.debug(`WebSocket instance created, waiting for connection...`);

            this.bridgeConnections(wsIncoming, wsOutgoing, {
                routeName,
                wsConfig,
                targetUrl,
                clientHost: req.headers.host,
                destination
            });
        }
        catch (error) {
            this.logger.error(`Error handling WebSocket connection: ${error.message}`);
            if (wsIncoming.readyState === wsIncoming.OPEN) {
                wsIncoming.close(1011, 'Internal server error');
            }
        }
    }

    /**
     * Check a request Origin against a route's allow-list.
     * Entries containing `*` are globs; entries containing `{...}` are template
     * strings resolved against the route context first. Other entries must match exactly.
     * @param {string} origin
     * @param {string[]|undefined} allowedOrigins - Empty or missing means "allow all".
     * @param {object} routeContext
     * @returns {boolean}
     */
    isOriginAllowed(origin, allowedOrigins, routeContext) {
        if (!allowedOrigins || allowedOrigins.length === 0) {
            return true;
        }
        return allowedOrigins.some((allowedOrigin) => {
            if (!allowedOrigin.includes('*') && !allowedOrigin.includes('{')) {
                return allowedOrigin === origin;
            }
            // Resolve templates first, then escape, so resolved values (and literal
            // dots) can never act as regex wildcards.
            const resolved = TemplateUtils.resolveTemplateVariables(allowedOrigin, routeContext);
            return wildcardToRegExp(resolved).test(origin);
        });
    }

    /**
     * Resolve the concrete `{ host, port }` for a target, evaluating function-based
     * host/port and picking a random host when an array is provided.
     * A port of 'preserve' keeps the incoming port.
     * @returns {{host: string, port: number}|null} null when a resolver threw (already logged).
     */
    resolveDestination(selectedTarget, routeContext) {
        try {
            let targetHost;
            if (typeof selectedTarget.host === 'function') {
                targetHost = selectedTarget.host(toBaseContext(routeContext));
                this.logger.debug(`Resolved function-based host for WebSocket: ${Array.isArray(targetHost) ? targetHost.join(', ') : targetHost}`);
            }
            else {
                targetHost = selectedTarget.host;
            }
            let targetPort;
            if (typeof selectedTarget.port === 'function') {
                targetPort = selectedTarget.port(toBaseContext(routeContext));
                this.logger.debug(`Resolved function-based port for WebSocket: ${targetPort}`);
            }
            else {
                targetPort = selectedTarget.port === 'preserve' ? routeContext.port : selectedTarget.port;
            }
            const selectedHost = Array.isArray(targetHost)
                ? targetHost[Math.floor(Math.random() * targetHost.length)]
                : targetHost;
            this.logger.debug(`WebSocket destination resolved: ${selectedHost}:${targetPort}`);
            return { host: selectedHost, port: targetPort };
        }
        catch (err) {
            this.logger.error(`Error evaluating function-based target for WebSocket: ${err}`);
            return null;
        }
    }

    /**
     * Build the upstream handshake headers.
     * String-valued client headers are copied, except the handshake headers `ws`
     * generates itself. Host is then rewritten to the destination, and the
     * route's customHeaders are applied:
     * - a plain value only sets a header that is not already present;
     * - '!value' forces an override;
     * - '!delete' removes the header.
     * @returns {Record<string, string>}
     */
    buildUpstreamHeaders(req, destination, wsConfig, routeContext) {
        const headers = {};
        for (const [key, value] of Object.entries(req.headers)) {
            if (value && typeof value === 'string' && !NON_FORWARDED_HEADERS.has(key.toLowerCase())) {
                headers[key] = value;
            }
        }
        // Always rewrite host header for WebSockets for consistency
        headers['host'] = `${destination.host}:${destination.port}`;
        if (wsConfig?.customHeaders) {
            for (const [key, value] of Object.entries(wsConfig.customHeaders)) {
                const name = key.toLowerCase();
                const isForced = value.startsWith('!');
                if (headers[name] && !isForced) {
                    continue;
                }
                if (value === '!delete') {
                    delete headers[name];
                    continue;
                }
                // NOTE(review): forced overrides are sent with their leading '!' kept
                // (e.g. '!foo' is sent as '!foo'); confirm the backend expects that.
                headers[name] = isForced
                    ? '!' + TemplateUtils.resolveTemplateVariables(value.substring(1), routeContext)
                    : TemplateUtils.resolveTemplateVariables(value, routeContext);
            }
        }
        return headers;
    }

    /**
     * Wire a client socket to its upstream socket: message forwarding in both
     * directions, buffering of client messages until the upstream opens, close
     * propagation, and the optional ping interval / ping timeout.
     * All listeners are attached immediately so a client that leaves during the
     * upstream handshake aborts that handshake instead of leaking it.
     * @param {object} wsIncoming - Client socket.
     * @param {object} wsOutgoing - Upstream socket (still CONNECTING).
     * @param {{routeName: string, wsConfig: object|undefined, targetUrl: string,
     *          clientHost: string|undefined, destination: {host: string, port: number}}} info
     */
    bridgeConnections(wsIncoming, wsOutgoing, { routeName, wsConfig, targetUrl, clientHost, destination }) {
        const maxSize = wsConfig?.maxPayloadSize || 0;
        const pendingMessages = [];
        let pendingBytes = 0;
        let pingInterval = null;
        let pingTimeout = null;
        const clearTimers = () => {
            if (pingInterval) clearInterval(pingInterval);
            if (pingTimeout) clearTimeout(pingTimeout);
            pingInterval = null;
            pingTimeout = null;
        };

        // Registered first: terminating a CONNECTING upstream emits 'error'.
        wsOutgoing.on('error', (err) => {
            this.logger.error(`WebSocket target connection error: ${err.message}`);
            if (wsIncoming.readyState === wsIncoming.OPEN) {
                wsIncoming.close(1011, 'Internal server error');
            }
        });

        // Client -> target. Only the size is logged: stringifying every payload is
        // costly and writes message contents into the logs.
        wsIncoming.on('message', (data, isBinary) => {
            const messageSize = getMessageSize(data);
            this.logger.debug(`WebSocket forwarding message from client to target (${messageSize} bytes)`);
            if (maxSize > 0 && messageSize > maxSize) {
                this.logger.warn(`WebSocket message exceeds max size (${messageSize} > ${maxSize})`);
                wsIncoming.close(1009, 'Message too big');
                return;
            }
            if (wsOutgoing.readyState === wsOutgoing.OPEN) {
                wsOutgoing.send(data, { binary: isBinary });
            }
            else if (wsOutgoing.readyState === wsOutgoing.CONNECTING) {
                // Hold messages sent before the upstream handshake finishes.
                if (pendingBytes + messageSize > MAX_PENDING_BYTES) {
                    this.logger.warn(`WebSocket backlog exceeded while target connection pending (${pendingBytes + messageSize} > ${MAX_PENDING_BYTES})`);
                    pendingMessages.length = 0;
                    pendingBytes = 0;
                    wsIncoming.close(1013, 'Target not ready');
                    return;
                }
                pendingMessages.push({ data, isBinary });
                pendingBytes += messageSize;
            }
            else {
                this.logger.warn(`WebSocket target connection not open (state: ${wsOutgoing.readyState})`);
            }
        });

        // Target -> client
        wsOutgoing.on('message', (data, isBinary) => {
            this.logger.debug(`WebSocket forwarding message from target to client (${getMessageSize(data)} bytes)`);
            if (wsIncoming.readyState === wsIncoming.OPEN) {
                wsIncoming.send(data, { binary: isBinary });
            }
            else {
                this.logger.warn(`WebSocket client connection not open (state: ${wsIncoming.readyState})`);
            }
        });

        wsIncoming.on('close', (code, reason) => {
            this.logger.debug(`WebSocket client connection closed: ${code} ${reason}`);
            pendingMessages.length = 0;
            pendingBytes = 0;
            if (wsOutgoing.readyState === wsOutgoing.CONNECTING) {
                // Client left mid-handshake: abort the upstream attempt instead of leaking it.
                wsOutgoing.terminate();
            }
            else {
                this.propagateClose(wsOutgoing, code, reason, 'wsOutgoing');
            }
            clearTimers();
        });

        wsOutgoing.on('close', (code, reason) => {
            this.logger.debug(`WebSocket target connection closed: ${code} ${reason}`);
            this.propagateClose(wsIncoming, code, reason, 'wsIncoming');
            clearTimers();
        });

        wsOutgoing.on('open', () => {
            this.logger.debug(`WebSocket target connection opened to ${targetUrl}`);
            if (wsIncoming.readyState !== wsIncoming.OPEN) {
                // The client started closing while the upstream handshake was in flight.
                pendingMessages.length = 0;
                pendingBytes = 0;
                wsOutgoing.close(1000);
                return;
            }
            // Flush, in order, everything the client sent during the handshake.
            for (const { data, isBinary } of pendingMessages) {
                wsOutgoing.send(data, { binary: isBinary });
            }
            pendingMessages.length = 0;
            pendingBytes = 0;

            // Set up custom ping interval if configured
            if (wsConfig?.pingInterval && wsConfig.pingInterval > 0) {
                pingInterval = setInterval(() => {
                    if (wsIncoming.readyState === wsIncoming.OPEN) {
                        wsIncoming.ping();
                        this.logger.debug(`Sent WebSocket ping to client for route: ${routeName}`);
                    }
                }, wsConfig.pingInterval);
                // Don't keep process alive just for pings
                if (pingInterval.unref) pingInterval.unref();
            }

            // Terminate the client if no pong arrives within pingTimeout (default 60 s).
            const pingTimeoutMs = wsConfig?.pingTimeout || 60000;
            const resetPingTimeout = () => {
                if (pingTimeout) clearTimeout(pingTimeout);
                pingTimeout = setTimeout(() => {
                    this.logger.debug(`WebSocket ping timeout for client connection on route: ${routeName}`);
                    wsIncoming.terminate();
                }, pingTimeoutMs);
                // Don't keep process alive just for timeouts
                if (pingTimeout.unref) pingTimeout.unref();
            };
            // isAlive/lastPong are already updated by the pong listener added on connection.
            wsIncoming.on('pong', resetPingTimeout);
            resetPingTimeout();

            this.logger.debug(`WebSocket connection established: ${clientHost} -> ${destination.host}:${destination.port}`);
        });
    }

    /**
     * Forward a close from one side of the bridge to `peer` if `peer` is still OPEN.
     * Unsendable codes (1005/1006/...) are mapped to 1000.
     * @param {object} peer - Socket to close.
     * @param {number} code - Code reported by the side that closed.
     * @param {Buffer|string|undefined} reason
     * @param {string} peerLabel - Name used in the error log.
     */
    propagateClose(peer, code, reason, peerLabel) {
        if (peer.readyState !== peer.OPEN) {
            return;
        }
        const sendableCode = toSendableCloseCode(code);
        try {
            const reasonString = reason ? toBuffer(reason).toString() : '';
            peer.close(sendableCode, reasonString);
        }
        catch (err) {
            this.logger.error(`Error closing ${peerLabel}:`, err);
            // A throwing close() may already have moved `ws` to CLOSING, which makes a
            // retried close() a no-op (observed in ws 8.x, confirm for the version in use);
            // drop the socket so it cannot hang half-closed.
            peer.terminate();
        }
    }

    /**
     * Get information about active WebSocket connections.
     * @returns {{activeConnections: number}}
     */
    getConnectionInfo() {
        return {
            activeConnections: this.wsServer ? this.wsServer.clients.size : 0
        };
    }

    /**
     * Shutdown the WebSocket handler: stop the heartbeat, terminate every tracked
     * client and close the WebSocket server.
     */
    shutdown() {
        if (this.heartbeatInterval) {
            clearInterval(this.heartbeatInterval);
            this.heartbeatInterval = null;
        }
        if (this.wsServer) {
            this.logger.info(`Closing ${this.wsServer.clients.size} WebSocket connections`);
            for (const client of this.wsServer.clients) {
                try {
                    client.terminate();
                }
                catch (error) {
                    this.logger.error('Error terminating WebSocket client', error);
                }
            }
            this.wsServer.close();
            this.wsServer = null;
        }
    }
}