UNPKG

rwsdk

Version:

Build fast, server-driven webapps on Cloudflare with SSR, RSC, and realtime

201 lines (200 loc) 7.84 kB
import { DurableObject } from "cloudflare:workers";
import { MESSAGE_TYPE } from "./shared";
import { validateUpgradeRequest } from "./validateUpgradeRequest";
import { packMessage, unpackMessage } from "./protocol";

/**
 * Durable Object that owns a set of realtime WebSocket clients.
 *
 * Each accepted socket carries its clientId as a serialized attachment. The
 * per-client info (page URL, cookies, forwarding flag) lives in an in-memory
 * Map backed by Durable Object storage under the key `client:<clientId>`.
 * Action requests received over a socket are proxied to the app as RSC POST
 * requests. The resulting response is streamed back to the requesting socket
 * as ACTION_* messages, and every other client is re-rendered via RSC_*
 * messages.
 */
export class RealtimeDurableObject extends DurableObject {
  /**
   * @param state - Durable Object state; `state.storage` persists client info.
   * @param env - Worker environment bindings.
   */
  constructor(state, env) {
    super(state, env);
    this.state = state;
    this.env = env;
    this.storage = state.storage;
    // clientId -> client info. Storage is the fallback on a cache miss
    // (presumably after the object is re-instantiated — verify).
    this.clientInfoCache = new Map();
  }

  /**
   * Accepts a WebSocket upgrade request and registers the connecting client.
   *
   * @param {Request} request - Upgrade request. `url`, `clientId` and
   *   `shouldForwardResponses` are read from its query string.
   * @returns {Promise<Response>} The validator's rejection response, or a 101
   *   response carrying the client end of a new WebSocket pair.
   */
  async fetch(request) {
    const validation = validateUpgradeRequest(request);
    if (!validation.valid) {
      return validation.response;
    }
    const url = new URL(request.url);
    const clientInfo = this.createClientInfo(url, request);
    return this.handleWebSocket(request, clientInfo);
  }

  /**
   * Builds the client info record for a connecting client.
   *
   * @param {URL} url - Parsed upgrade request URL.
   * @param {Request} request - Upgrade request; its Cookie header is captured
   *   so later server fetches can be made on the client's behalf.
   * @returns {{url: string|null, clientId: string|null, cookieHeaders: string,
   *   shouldForwardResponses: boolean}}
   */
  createClientInfo(url, request) {
    return {
      url: url.searchParams.get("url"),
      clientId: url.searchParams.get("clientId"),
      cookieHeaders: request.headers.get("Cookie") || "",
      shouldForwardResponses:
        url.searchParams.get("shouldForwardResponses") === "true",
    };
  }

  /**
   * Caches and persists client info under `client:<clientId>`.
   *
   * @param clientInfo - Record produced by createClientInfo (or an updated copy).
   */
  async storeClientInfo(clientInfo) {
    this.clientInfoCache.set(clientInfo.clientId, clientInfo);
    await this.storage.put(`client:${clientInfo.clientId}`, clientInfo);
  }

  /**
   * Looks up client info, checking the in-memory cache before storage.
   *
   * @param clientId - Id stored as the socket attachment.
   * @returns The stored client info; a storage hit is re-cached.
   * @throws {Error} When no info exists for `clientId`.
   */
  async getClientInfo(clientId) {
    const cachedInfo = this.clientInfoCache.get(clientId);
    if (cachedInfo) {
      return cachedInfo;
    }
    const clientInfo = await this.storage.get(`client:${clientId}`);
    if (!clientInfo) {
      throw new Error(`Client info not found for clientId: ${clientId}`);
    }
    this.clientInfoCache.set(clientId, clientInfo);
    return clientInfo;
  }

  /**
   * Creates the socket pair, persists the client, and accepts the server end.
   *
   * @returns {Promise<Response>} 101 response handing `client` to the caller.
   */
  async handleWebSocket(request, clientInfo) {
    const { 0: client, 1: server } = new WebSocketPair();
    await this.storeClientInfo(clientInfo);
    // Only the clientId travels with the socket. Everything else is looked up
    // through getClientInfo when the socket is used.
    server.serializeAttachment(clientInfo.clientId);
    // Accepted through the Durable Object state (not server.accept()), so
    // socket events are handled by this class's webSocketMessage /
    // webSocketClose methods.
    this.state.acceptWebSocket(server);
    return new Response(null, { status: 101, webSocket: client });
  }

  /**
   * Handles an incoming socket frame. Only ACTION_REQUEST messages are acted
   * on; any other message type is ignored.
   *
   * On an action request, the client's stored URL is replaced with
   * `message.clientUrl` and the action is run. If the action fails, an
   * ACTION_ERROR message keyed by `message.requestId` is sent back.
   *
   * @param {WebSocket} ws - Socket that received the frame.
   * @param data - Frame payload.
   */
  async webSocketMessage(ws, data) {
    const clientId = ws.deserializeAttachment();
    let clientInfo = await this.getClientInfo(clientId);
    // NOTE(review): assumes `data` is binary (ArrayBuffer); a string frame is
    // not text-decoded by this conversion.
    const unpacked = unpackMessage(new Uint8Array(data));
    if (unpacked.type !== MESSAGE_TYPE.ACTION_REQUEST) {
      return;
    }
    const message = unpacked;
    // Remember the page the client is currently on so later re-renders fetch
    // the right URL.
    clientInfo = { ...clientInfo, url: message.clientUrl };
    await this.storeClientInfo(clientInfo);
    try {
      await this.handleAction(
        ws,
        message.id,
        message.args,
        clientInfo,
        message.requestId,
        message.clientUrl,
      );
    } catch (error) {
      ws.send(
        packMessage({
          type: MESSAGE_TYPE.ACTION_ERROR,
          id: message.requestId,
          error: error instanceof Error ? error.message : String(error),
        }),
      );
    }
  }

  /**
   * Relays a fetch Response over a socket as start / chunk... / end messages.
   *
   * @param {Response} response - Response to relay; `status` goes into the
   *   start message and each body chunk becomes one chunk message.
   * @param {WebSocket} ws - Destination socket.
   * @param {{start: *, chunk: *, end: *}} messageTypes - MESSAGE_TYPE values
   *   for the three message kinds (ACTION_* or RSC_*).
   * @param streamId - Id stamped on every message of this stream.
   * @throws Rethrows read or send failures after cancelling the upstream body.
   */
  async streamResponse(response, ws, messageTypes, streamId) {
    ws.send(
      packMessage({
        type: messageTypes.start,
        id: streamId,
        status: response.status,
      }),
    );
    // Response.body is null for body-less responses (e.g. 204). Close the
    // stream immediately instead of crashing on getReader().
    if (!response.body) {
      ws.send(packMessage({ type: messageTypes.end, id: streamId }));
      return;
    }
    const reader = response.body.getReader();
    try {
      for (;;) {
        const { done, value } = await reader.read();
        if (done) {
          break;
        }
        ws.send(
          packMessage({ type: messageTypes.chunk, id: streamId, payload: value }),
        );
      }
      ws.send(packMessage({ type: messageTypes.end, id: streamId }));
    } catch (error) {
      // Stop pulling from upstream once the relay has failed. The cancel is
      // best-effort (it rejects on an already-errored stream); the original
      // error is what the caller needs to see.
      await reader.cancel(error).catch(() => {});
      throw error;
    } finally {
      reader.releaseLock();
    }
  }

  /**
   * Runs a server action by POSTing `args` to `clientUrl` as an RSC action
   * request, then streams the result back to `ws` as ACTION_* messages.
   *
   * Other clients are re-rendered concurrently; that render does not delay
   * this action's response.
   *
   * @param {WebSocket} ws - Socket of the client that invoked the action.
   * @param id - Action id; sent as `__rsc_action_id` when not null/undefined.
   * @param args - Request body forwarded unchanged.
   * @param clientInfo - Caller's client info (cookies, forwarding flag, id).
   * @param requestId - Stream id used for the ACTION_* messages.
   * @param {string} clientUrl - Page URL the action is posted to.
   * @throws {Error} When the response is not OK and
   *   `clientInfo.shouldForwardResponses` is false.
   */
  async handleAction(ws, id, args, clientInfo, requestId, clientUrl) {
    const url = new URL(clientUrl);
    url.searchParams.set("__rsc", "");
    if (id != null) {
      url.searchParams.set("__rsc_action_id", id);
    }
    const response = await fetch(url.toString(), {
      method: "POST",
      body: args,
      headers: {
        "Content-Type": "application/json",
        Cookie: clientInfo.cookieHeaders,
      },
    });
    if (!response.ok && !clientInfo.shouldForwardResponses) {
      // The body will never be read; release it before reporting the failure.
      await response.body?.cancel().catch(() => {});
      throw new Error(`Action failed: ${response.statusText}`);
    }
    // Deliberately not awaited, so the action response is not held up by
    // other clients' re-renders. The handler keeps a failed render from
    // becoming an unhandled promise rejection.
    this.render({
      exclude: [clientInfo.clientId],
    }).catch((error) => {
      console.error("Failed to render realtime clients:", error);
    });
    await this.streamResponse(
      response,
      ws,
      {
        start: MESSAGE_TYPE.ACTION_START,
        chunk: MESSAGE_TYPE.ACTION_CHUNK,
        end: MESSAGE_TYPE.ACTION_END,
      },
      requestId,
    );
  }

  /**
   * Selects connected sockets by clientId and pairs each with its client info.
   *
   * @param {{include?: string[], exclude?: string[]}} [options]
   *   - `include`: if non-empty, only these clientIds are selected.
   *   - `exclude`: these clientIds are never selected (wins over include).
   * @returns {Promise<Array<{socket: WebSocket, clientInfo: object}>>}
   *   Sockets whose client info could not be loaded are logged and skipped,
   *   so one stale socket cannot block updates to everyone else.
   */
  async determineSockets({ include = [], exclude = [] } = {}) {
    const includeSet = include.length > 0 ? new Set(include) : null;
    const excludeSet = new Set(exclude);
    const selected = Array.from(this.state.getWebSockets()).filter((socket) => {
      const clientId = socket.deserializeAttachment();
      if (excludeSet.has(clientId)) {
        return false;
      }
      return includeSet === null || includeSet.has(clientId);
    });
    const entries = await Promise.all(
      selected.map(async (socket) => {
        const clientId = socket.deserializeAttachment();
        try {
          const clientInfo = await this.getClientInfo(clientId);
          return { socket, clientInfo };
        } catch (error) {
          console.error(
            `Skipping socket with unavailable client info (${clientId}):`,
            error,
          );
          return null;
        }
      }),
    );
    return entries.filter((entry) => entry !== null);
  }

  /**
   * Re-renders the selected clients. For each one, the client's stored page
   * URL is fetched with `__rsc=true`, and the response is streamed as RSC_*
   * messages under a fresh random stream id.
   *
   * Per-socket failures are logged and do not affect other sockets.
   *
   * @param {{include?: string[], exclude?: string[]}} [options] - clientId
   *   filters, as in determineSockets.
   */
  async render({ include, exclude } = {}) {
    const sockets = await this.determineSockets({ include, exclude });
    if (sockets.length === 0) return;
    await Promise.all(
      sockets.map(async ({ socket, clientInfo }) => {
        try {
          const url = new URL(clientInfo.url);
          url.searchParams.set("__rsc", "true");
          const response = await fetch(url.toString(), {
            headers: {
              "Content-Type": "application/json",
              Cookie: clientInfo.cookieHeaders,
            },
          });
          if (!response.ok) {
            console.error(`Failed to fetch RSC update: ${response.statusText}`);
            // Release the unread body rather than leaving the connection open.
            await response.body?.cancel().catch(() => {});
            return;
          }
          const rscId = crypto.randomUUID();
          await this.streamResponse(
            response,
            socket,
            {
              start: MESSAGE_TYPE.RSC_START,
              chunk: MESSAGE_TYPE.RSC_CHUNK,
              end: MESSAGE_TYPE.RSC_END,
            },
            rscId,
          );
        } catch (err) {
          console.error("Failed to process socket:", err);
        }
      }),
    );
  }

  /** Drops a client's info from both the cache and storage. */
  async removeClientInfo(clientId) {
    this.clientInfoCache.delete(clientId);
    await this.storage.delete(`client:${clientId}`);
  }

  /**
   * Socket close handler: forgets the client attached to `ws`.
   *
   * @param {WebSocket} ws - The closed socket.
   */
  async webSocketClose(ws) {
    const clientId = ws.deserializeAttachment();
    await this.removeClientInfo(clientId);
  }
}