UNPKG

@rivetkit/redis

Version:

_Lightweight Libraries for Backends_

425 lines (372 loc) 10.6 kB
import { toUint8Array } from "@rivetkit/core";
import { logger } from "../log";
import type { RelayConn } from "../relay-conn";
import { LeaderChangedError } from "./message";
import type { Node } from "./mod";
import type { NodeMessage } from "./protocol";

/**
 * `WebSocket`-compatible adapter for a socket whose traffic is relayed to the
 * actor leader through a {@link RelayConn}.
 *
 * - Outgoing frames (`send`) are published to the leader as `lwm` messages.
 * - Close requests (`close`) are sent as `lwc` payloads via `RelayConn.disconnect`.
 * - Incoming events are fed in by the owner through the internal
 *   `_handleOpen` / `_handleMessage` / `_handleClose` hooks.
 *
 * Events fired before any handler for their type exists are buffered. They are
 * replayed when the first `addEventListener` listener or `onX` property handler
 * for that type is registered.
 */
export class RelayWebSocketAdapter implements WebSocket {
	#node: Node;
	#websocketId: string;
	#relayConn: RelayConn;
	#readyState: number = WebSocket.CONNECTING;
	#eventListeners: Map<string, Set<(event: any) => void>> = new Map();
	#onopen: ((this: WebSocket, ev: Event) => any) | null = null;
	#onclose: ((this: WebSocket, ev: any) => any) | null = null;
	#onerror: ((this: WebSocket, ev: Event) => any) | null = null;
	#onmessage: ((this: WebSocket, ev: MessageEvent) => any) | null = null;
	// Never updated in this class; always reports 0.
	#bufferedAmount = 0;
	#binaryType: "blob" | "arraybuffer" = "blob";
	#extensions = "";
	#protocol = "";
	#url = "";
	#openPromise: Promise<void>;
	#openResolve!: () => void;
	// Event buffering is needed since events can be fired before JavaScript has
	// a chance to add event listeners (e.g. within the same tick).
	#bufferedEvents: Array<{
		type: string;
		event: any;
	}> = [];

	/**
	 * @param node - Node whose `globalState.relayWebSockets` map this socket
	 *   registers itself in, keyed by `websocketId`.
	 * @param websocketId - Id carried as `wi` in every message sent to the leader.
	 * @param relayConn - Connection used to publish frames and close requests.
	 */
	constructor(node: Node, websocketId: string, relayConn: RelayConn) {
		this.#node = node;
		this.#websocketId = websocketId;
		this.#relayConn = relayConn;

		this.#openPromise = new Promise<void>((resolve) => {
			this.#openResolve = resolve;
		});

		// Register this WebSocket with the global state (creating the map lazily).
		const state = this.#globalState();
		state.relayWebSockets = state.relayWebSockets || new Map();
		state.relayWebSockets.set(websocketId, this);

		logger().debug("relay websocket adapter registered", {
			websocketId,
			nodeId: state.nodeId,
			relayWebSocketsSize: state.relayWebSockets.size,
		});
	}

	/**
	 * Resolves once `_handleOpen` moves the socket to OPEN.
	 *
	 * NOTE(review): nothing in this class rejects or resolves it when the
	 * socket closes before opening, so awaiters would hang — confirm intended.
	 */
	get openPromise(): Promise<void> {
		return this.#openPromise;
	}

	get readyState(): number {
		return this.#readyState;
	}

	get bufferedAmount(): number {
		return this.#bufferedAmount;
	}

	/** Stored only; `_handleMessage` delivers binary data as `Uint8Array` regardless. */
	get binaryType(): "blob" | "arraybuffer" {
		return this.#binaryType;
	}

	set binaryType(value: "blob" | "arraybuffer") {
		this.#binaryType = value;
	}

	get extensions(): string {
		return this.#extensions;
	}

	get protocol(): string {
		return this.#protocol;
	}

	get url(): string {
		return this.#url;
	}

	get actorId(): string {
		return this.#relayConn.actorId;
	}

	get onopen(): ((this: WebSocket, ev: Event) => any) | null {
		return this.#onopen;
	}

	set onopen(value: ((this: WebSocket, ev: Event) => any) | null) {
		this.#onopen = value;
		// Replay buffered open events to the new handler.
		if (value) {
			this.#flushBufferedEvents("open");
		}
	}

	get onclose(): ((this: WebSocket, ev: any) => any) | null {
		return this.#onclose;
	}

	set onclose(value: ((this: WebSocket, ev: any) => any) | null) {
		this.#onclose = value;
		// Replay buffered close events to the new handler.
		if (value) {
			this.#flushBufferedEvents("close");
		}
	}

	get onerror(): ((this: WebSocket, ev: Event) => any) | null {
		return this.#onerror;
	}

	set onerror(value: ((this: WebSocket, ev: Event) => any) | null) {
		this.#onerror = value;
		// Replay buffered error events to the new handler.
		if (value) {
			this.#flushBufferedEvents("error");
		}
	}

	get onmessage(): ((this: WebSocket, ev: MessageEvent) => any) | null {
		return this.#onmessage;
	}

	set onmessage(value: ((this: WebSocket, ev: MessageEvent) => any) | null) {
		this.#onmessage = value;
		// Replay buffered message events to the new handler.
		if (value) {
			this.#flushBufferedEvents("message");
		}
	}

	/**
	 * Publishes a frame to the leader without retry.
	 *
	 * A `LeaderChangedError` from the publish closes this socket with 1001;
	 * any other publish failure fires an `error` event.
	 *
	 * @throws DOMException (`InvalidStateError`) if the socket is not OPEN.
	 * @throws Error for `Blob` payloads (not supported) or unknown data types.
	 */
	send(data: string | ArrayBufferLike | Blob | ArrayBufferView): void {
		if (this.#readyState !== WebSocket.OPEN) {
			throw new DOMException("WebSocket is not open", "InvalidStateError");
		}

		// Convert data to appropriate format
		let isBinary = false;
		let messageData: string | Uint8Array;
		if (typeof data === "string") {
			messageData = data;
		} else if (data instanceof ArrayBuffer || ArrayBuffer.isView(data)) {
			isBinary = true;
			messageData = toUint8Array(data);
		} else if (data instanceof Blob) {
			throw new Error("Blob sending not implemented in relay adapter");
		} else {
			throw new Error("Invalid data type");
		}

		const message: NodeMessage = {
			b: {
				lwm: {
					wi: this.#websocketId,
					data: messageData,
					binary: isBinary,
				},
			},
		};

		void this.#relayConn
			.publishMessageToleader(message, false)
			.catch((error: unknown) => {
				// Handle leader change by closing the WebSocket
				if (error instanceof LeaderChangedError) {
					this._handleClose(1001, "Actor leader changed");
				} else {
					this.#fireEvent("error", new Event("error"));
				}
			});
	}

	/**
	 * Moves to CLOSING and asks the leader to close the socket (an `lwc` payload
	 * via `RelayConn.disconnect`). Once that request settles, the socket is
	 * CLOSED and a `close` event fires. No-op if already CLOSING or CLOSED.
	 */
	close(code?: number, reason?: string): void {
		if (
			this.#readyState === WebSocket.CLOSING ||
			this.#readyState === WebSocket.CLOSED
		) {
			return;
		}

		this.#readyState = WebSocket.CLOSING;

		void this.#relayConn
			.disconnect(false, "Client closed WebSocket", {
				b: {
					lwc: {
						wi: this.#websocketId,
						code,
						reason,
					},
				},
			})
			.catch((error: unknown) => {
				// Log instead of leaving the rejection unhandled; the socket is still
				// closed locally below.
				logger().error("failed to send websocket close to leader", {
					websocketId: this.#websocketId,
					error,
				});
			})
			.finally(() => {
				// `_handleClose` may already have closed the socket while the request
				// was in flight; `#finishClose` then does nothing, so only one close
				// event is delivered.
				this.#finishClose(code, reason);
			});
	}

	/**
	 * Registers a function listener and replays any buffered events of `type`.
	 * Listener objects (`{ handleEvent }`) and `options` are ignored.
	 */
	addEventListener(type: string, listener: any, options?: boolean | any): void {
		if (typeof listener !== "function") {
			return;
		}

		let listeners = this.#eventListeners.get(type);
		if (!listeners) {
			listeners = new Set();
			this.#eventListeners.set(type, listeners);
		}
		listeners.add(listener);

		// Flush any buffered events for this type
		logger().debug(`flushing buffered events for ${type}`, {
			websocketId: this.#websocketId,
			bufferedEventsCount: this.#bufferedEvents.filter((e) => e.type === type)
				.length,
		});
		this.#flushBufferedEvents(type);
	}

	removeEventListener(
		type: string,
		listener: any,
		options?: boolean | any,
	): void {
		if (typeof listener === "function") {
			this.#eventListeners.get(type)?.delete(listener);
		}
	}

	/**
	 * Synchronously delivers `event` to the listeners and `onX` handler for
	 * `event.type`. Externally dispatched events are not buffered.
	 *
	 * @returns `false` if a handler called `preventDefault()` on a cancelable
	 *   event, otherwise `true`.
	 */
	dispatchEvent(event: Event): boolean {
		this.#dispatch(event.type, event);
		return !event.defaultPrevented;
	}

	/** Untyped view of the node's global state, where the socket registry lives. */
	// eslint-disable-next-line @typescript-eslint/no-explicit-any
	#globalState(): any {
		return this.#node.globalState;
	}

	/** Returns the `onX` property handler for `type`, or null. */
	#propertyHandler(type: string): ((this: WebSocket, ev: any) => any) | null {
		switch (type) {
			case "open":
				return this.#onopen;
			case "close":
				return this.#onclose;
			case "error":
				return this.#onerror;
			case "message":
				return this.#onmessage;
			default:
				return null;
		}
	}

	/**
	 * Calls every listener for `type`, then the `onX` handler. Each call is
	 * isolated with try/catch and logged on error.
	 *
	 * @returns whether at least one handler was invoked.
	 */
	#dispatch(type: string, event: any): boolean {
		let delivered = false;

		const listeners = this.#eventListeners.get(type);
		if (listeners && listeners.size > 0) {
			delivered = true;
			// Snapshot so listeners added or removed during dispatch don't affect this round.
			for (const listener of [...listeners]) {
				try {
					listener.call(this, event);
				} catch (error) {
					logger().error("error in websocket event listener", { error, type });
				}
			}
		}

		const handler = this.#propertyHandler(type);
		if (handler) {
			delivered = true;
			try {
				handler.call(this, event);
			} catch (error) {
				logger().error(`error in on${type} handler`, { error });
			}
		}

		return delivered;
	}

	/** Dispatches `event`, buffering it for later replay if no handler exists. */
	#fireEvent(type: string, event: any): void {
		if (!this.#dispatch(type, event)) {
			logger().debug(`no ${type} listeners registered, buffering event`);
			this.#bufferedEvents.push({ type, event });
		}
	}

	/**
	 * Replays buffered events of `type`, in order, to both `addEventListener`
	 * listeners and the `onX` handler. Events that still find no handler (for
	 * example, one removed mid-replay) are buffered again.
	 */
	#flushBufferedEvents(type: string): void {
		const pending = this.#bufferedEvents.filter(
			(buffered) => buffered.type === type,
		);
		if (pending.length === 0) {
			return;
		}
		this.#bufferedEvents = this.#bufferedEvents.filter(
			(buffered) => buffered.type !== type,
		);
		for (const { event } of pending) {
			this.#fireEvent(type, event);
		}
	}

	/**
	 * Transitions to CLOSED exactly once: unregisters from the global socket map
	 * and fires a single `close` event. Later calls are no-ops.
	 */
	#finishClose(code?: number, reason?: string): void {
		if (this.#readyState === WebSocket.CLOSED) {
			return;
		}

		this.#readyState = WebSocket.CLOSED;
		this.#globalState().relayWebSockets?.delete(this.#websocketId);

		this.#fireEvent("close", {
			type: "close",
			target: this,
			code: code ?? 1000,
			reason: reason ?? "",
			wasClean: true,
		});
	}

	/**
	 * Internal hook: delivers a frame from the leader. Dropped unless OPEN.
	 * Binary frames are passed through as `Uint8Array`.
	 *
	 * @throws Error if `isBinary` is set but `data` is not a `Uint8Array`.
	 */
	_handleMessage(data: string | Uint8Array, isBinary: boolean): void {
		if (this.#readyState !== WebSocket.OPEN) {
			return;
		}

		if (isBinary && !(data instanceof Uint8Array)) {
			throw new Error("Binary data must be Uint8Array");
		}

		const event = new MessageEvent("message", {
			data,
			origin: "",
			lastEventId: "",
		});
		this.#fireEvent("message", event);
	}

	/**
	 * Internal hook: the leader confirmed the socket is open. Only acts while
	 * CONNECTING. Resolves `openPromise`, then fires `open`.
	 */
	_handleOpen(): void {
		logger().debug("_handleOpen called", {
			websocketId: this.#websocketId,
			currentReadyState: this.#readyState,
			isConnecting: this.#readyState === WebSocket.CONNECTING,
		});

		if (this.#readyState !== WebSocket.CONNECTING) {
			return;
		}

		this.#readyState = WebSocket.OPEN;
		this.#openResolve();
		this.#fireEvent("open", new Event("open"));
	}

	/** Internal hook: the leader (or a leader change) closed the socket. */
	_handleClose(code?: number, reason?: string): void {
		this.#finishClose(code, reason);
	}

	// Required WebSocket constants
	static readonly CONNECTING = 0;
	static readonly OPEN = 1;
	static readonly CLOSING = 2;
	static readonly CLOSED = 3;

	// Instance constants
	readonly CONNECTING = 0;
	readonly OPEN = 1;
	readonly CLOSING = 2;
	readonly CLOSED = 3;
}