UNPKG

@redwoodjs/sdk

Version:

A full-stack webapp toolkit designed for TypeScript, Vite, and React Server Components

175 lines (174 loc) 6.77 kB
import { DurableObject } from "cloudflare:workers";
import { MESSAGE_TYPE } from "./shared";
import { validateUpgradeRequest } from "./validateUpgradeRequest";

/**
 * Durable Object that owns the realtime WebSocket connections.
 *
 * - Accepts WebSocket upgrades and remembers, per client, the page URL to
 *   render, the client id, and the Cookie header of the upgrade request.
 * - Executes RSC action requests received over a socket by POSTing them to the
 *   client's page URL and streaming the result back to the requesting socket.
 * - Re-renders the page (GET with `__rsc=true`) for a selected set of connected
 *   clients and streams the RSC payload to each of them.
 *
 * Wire format used on the sockets: byte 0 is a `MESSAGE_TYPE` value and the
 * remaining bytes are the payload.
 *
 * Client info is cached in an in-memory Map and mirrored to Durable Object
 * storage under `client:<clientId>`. Each socket's serialized attachment is its
 * clientId, so the info can be reloaded from storage whenever the in-memory
 * cache misses (presumably after the object is re-instantiated — confirm).
 */
export class RealtimeDurableObject extends DurableObject {
  /**
   * @param {DurableObjectState} state - Durable Object state (storage, sockets).
   * @param {object} env - Worker environment bindings.
   */
  constructor(state, env) {
    super(state, env);
    this.state = state;
    this.env = env;
    this.storage = state.storage;
    // In-memory front for `this.storage`; empty again whenever the object is re-created.
    this.clientInfoCache = new Map();
  }

  /**
   * Handles a WebSocket upgrade request from a realtime client.
   *
   * Expects `url` (the page URL to render for this client) and `clientId`
   * query parameters on the request URL.
   *
   * @param {Request} request - Incoming upgrade request.
   * @returns {Promise<Response>} The validator's response when the upgrade is
   *   invalid, 400 when `url` or `clientId` is missing, otherwise a 101
   *   response carrying the client end of a new WebSocket pair.
   */
  async fetch(request) {
    const validation = validateUpgradeRequest(request);
    if (!validation.valid) {
      return validation.response;
    }
    const url = new URL(request.url);
    const clientInfo = this.createClientInfo(url, request);
    // Without these every later render/action fails on `new URL(null)`, and
    // id-less clients would all share the `client:null` storage record.
    if (!clientInfo.url || !clientInfo.clientId) {
      return new Response("Missing url or clientId query parameter", {
        status: 400,
      });
    }
    return this.handleWebSocket(request, clientInfo);
  }

  /**
   * Builds the per-client record persisted for a connection.
   *
   * @param {URL} url - Parsed URL of the upgrade request.
   * @param {Request} request - Upgrade request (its Cookie header is kept).
   * @returns {{url: string|null, clientId: string|null, cookieHeaders: string}}
   */
  createClientInfo(url, request) {
    return {
      url: url.searchParams.get("url"),
      clientId: url.searchParams.get("clientId"),
      cookieHeaders: request.headers.get("Cookie") || "",
    };
  }

  /**
   * Caches `clientInfo` in memory and writes it to storage under
   * `client:<clientId>`.
   *
   * @param {{clientId: string}} clientInfo
   */
  async storeClientInfo(clientInfo) {
    this.clientInfoCache.set(clientInfo.clientId, clientInfo);
    await this.storage.put(`client:${clientInfo.clientId}`, clientInfo);
  }

  /**
   * Returns the stored info for `clientId`, reading storage on a cache miss and
   * repopulating the cache.
   *
   * @param {string} clientId
   * @returns {Promise<object>} The client info record.
   * @throws {Error} If no record exists for `clientId`.
   */
  async getClientInfo(clientId) {
    const cachedInfo = this.clientInfoCache.get(clientId);
    if (cachedInfo) {
      return cachedInfo;
    }
    const clientInfo = await this.storage.get(`client:${clientId}`);
    if (!clientInfo) {
      throw new Error(`Client info not found for clientId: ${clientId}`);
    }
    this.clientInfoCache.set(clientId, clientInfo);
    return clientInfo;
  }

  /**
   * Creates a WebSocket pair, persists the client info, tags the server socket
   * with the clientId, and accepts it on the Durable Object state.
   *
   * @param {Request} request - Upgrade request (currently unused).
   * @param {{clientId: string}} clientInfo
   * @returns {Promise<Response>} 101 response carrying the client socket.
   */
  async handleWebSocket(request, clientInfo) {
    const { 0: client, 1: server } = new WebSocketPair();
    await this.storeClientInfo(clientInfo);
    // The attachment is how later handlers map a socket back to its client record.
    server.serializeAttachment(clientInfo.clientId);
    this.state.acceptWebSocket(server);
    return new Response(null, { status: 101, webSocket: client });
  }

  /**
   * Handles an incoming socket frame. Only `MESSAGE_TYPE.ACTION_REQUEST` frames
   * are acted on; their payload is UTF-8 JSON `{ id, args }`.
   *
   * Failures while running the action (including a missing client record) are
   * reported to the socket as an `ACTION_ERROR` frame carrying
   * `{ id, error }` JSON. Frames whose JSON cannot be parsed are logged and
   * dropped, since no action id is available to answer.
   *
   * @param {WebSocket} ws - Socket that sent the frame.
   * @param {ArrayBuffer|string} data - Frame payload.
   */
  async webSocketMessage(ws, data) {
    const message = new Uint8Array(data);
    const messageType = message[0];
    if (messageType !== MESSAGE_TYPE.ACTION_REQUEST) {
      return;
    }

    let id;
    let args;
    try {
      const jsonData = new TextDecoder().decode(message.slice(1));
      ({ id, args } = JSON.parse(jsonData));
    } catch (error) {
      console.error("Failed to parse realtime action request:", error);
      return;
    }

    try {
      const clientId = ws.deserializeAttachment();
      const clientInfo = await this.getClientInfo(clientId);
      await this.handleAction(ws, id, args, clientInfo);
    } catch (error) {
      const errorData = JSON.stringify({
        id,
        error: error instanceof Error ? error.message : String(error),
      });
      const errorBytes = new TextEncoder().encode(errorData);
      const errorResponse = new Uint8Array(errorBytes.length + 1);
      errorResponse[0] = MESSAGE_TYPE.ACTION_ERROR;
      errorResponse.set(errorBytes, 1);
      ws.send(errorResponse);
    }
  }

  /**
   * Forwards `response`'s body to `ws`. Each chunk is prefixed with
   * `messageTypes.chunk`, and a single `messageTypes.end` byte is sent when the
   * body is exhausted.
   *
   * @param {Response} response - Response whose body is streamed.
   * @param {WebSocket} ws - Destination socket.
   * @param {{chunk: number, end: number}} messageTypes - Frame type bytes.
   */
  async streamResponse(response, ws, messageTypes) {
    // Bodiless responses (e.g. 204) have `body === null`; treat them as empty.
    if (!response.body) {
      ws.send(new Uint8Array([messageTypes.end]));
      return;
    }
    const reader = response.body.getReader();
    try {
      while (true) {
        const { done, value } = await reader.read();
        if (done) {
          ws.send(new Uint8Array([messageTypes.end]));
          break;
        }
        const chunkMessage = new Uint8Array(value.length + 1);
        chunkMessage[0] = messageTypes.chunk;
        chunkMessage.set(value, 1);
        ws.send(chunkMessage);
      }
    } finally {
      reader.releaseLock();
    }
  }

  /**
   * Runs an RSC action for a client. It POSTs `args` to the client's page URL
   * with `__rsc=true&__rsc_action_id=<id>` and the client's cookies, triggers a
   * re-render for every other client, and streams the action result back to
   * `ws` as ACTION_CHUNK / ACTION_END frames.
   *
   * @param {WebSocket} ws - Requesting socket.
   * @param {string} id - Action id.
   * @param {*} args - Request body. NOTE(review): presumably an already
   *   serialized string; fetch would send an object as "[object Object]".
   * @param {{url: string, clientId: string, cookieHeaders: string}} clientInfo
   * @throws {Error} When the action endpoint answers with a non-2xx status.
   */
  async handleAction(ws, id, args, clientInfo) {
    const url = new URL(clientInfo.url);
    url.searchParams.set("__rsc", "true");
    url.searchParams.set("__rsc_action_id", id);
    const response = await fetch(url.toString(), {
      method: "POST",
      body: args,
      headers: {
        "Content-Type": "application/json",
        Cookie: clientInfo.cookieHeaders,
      },
    });
    if (!response.ok) {
      throw new Error(`Action failed: ${response.statusText}`);
    }
    // Intentionally not awaited so other clients re-render while the result
    // streams back to the caller. Observe the rejection so a failed broadcast is
    // logged instead of surfacing as an unhandled rejection.
    this.render({
      exclude: [clientInfo.clientId],
    }).catch((error) => {
      console.error("Failed to re-render clients after action:", error);
    });
    await this.streamResponse(response, ws, {
      chunk: MESSAGE_TYPE.ACTION_CHUNK,
      end: MESSAGE_TYPE.ACTION_END,
    });
  }

  /**
   * Selects the accepted sockets to target and resolves each one's client info.
   *
   * A socket is skipped if its clientId is in `exclude`, or if `include` is
   * non-empty and does not contain it. Sockets whose client info cannot be
   * loaded are logged and skipped, so a single stale record does not abort
   * the whole selection.
   *
   * @param {{include?: string[], exclude?: string[]}} [options]
   * @returns {Promise<Array<{socket: WebSocket, clientInfo: object}>>}
   */
  async determineSockets({ include = [], exclude = [] } = {}) {
    const sockets = Array.from(this.state.getWebSockets());
    const includeSet = include?.length > 0 ? new Set(include) : null;
    const excludeSet = exclude?.length > 0 ? new Set(exclude) : null;
    const results = [];
    for (const socket of sockets) {
      const clientId = socket.deserializeAttachment();
      if (excludeSet?.has(clientId)) {
        continue;
      }
      if (includeSet && !includeSet.has(clientId)) {
        continue;
      }
      try {
        const clientInfo = await this.getClientInfo(clientId);
        results.push({ socket, clientInfo });
      } catch (error) {
        console.error(`Skipping socket for clientId ${clientId}:`, error);
      }
    }
    return results;
  }

  /**
   * Re-renders the page for the selected clients in parallel. For each one it
   * GETs `<clientInfo.url>?__rsc=true` with the client's cookies, then sends
   * RSC_START followed by the streamed body as RSC_CHUNK / RSC_END frames.
   * Per-socket failures are logged and do not affect other sockets.
   *
   * @param {{include?: string[], exclude?: string[]}} [options] - Client filter,
   *   passed to `determineSockets`.
   */
  async render({ include, exclude } = {}) {
    const sockets = await this.determineSockets({ include, exclude });
    if (sockets.length === 0) return;
    await Promise.all(
      sockets.map(async ({ socket, clientInfo }) => {
        try {
          const url = new URL(clientInfo.url);
          url.searchParams.set("__rsc", "true");
          const response = await fetch(url.toString(), {
            headers: {
              "Content-Type": "application/json",
              Cookie: clientInfo.cookieHeaders,
            },
          });
          if (!response.ok) {
            console.error(`Failed to fetch RSC update: ${response.statusText}`);
            return;
          }
          socket.send(new Uint8Array([MESSAGE_TYPE.RSC_START]));
          await this.streamResponse(response, socket, {
            chunk: MESSAGE_TYPE.RSC_CHUNK,
            end: MESSAGE_TYPE.RSC_END,
          });
        } catch (err) {
          console.error("Failed to process socket:", err);
        }
      }),
    );
  }

  /**
   * Deletes a client's info from the in-memory cache and from storage.
   *
   * @param {string} clientId
   */
  async removeClientInfo(clientId) {
    this.clientInfoCache.delete(clientId);
    await this.storage.delete(`client:${clientId}`);
  }

  /**
   * Socket close handler: forgets the closing client's stored info.
   *
   * @param {WebSocket} ws - Socket being closed.
   */
  async webSocketClose(ws) {
    const clientId = ws.deserializeAttachment();
    await this.removeClientInfo(clientId);
  }
}