rwsdk
Version:
Build fast, server-driven webapps on Cloudflare with SSR, RSC, and realtime
201 lines (200 loc) • 7.84 kB
JavaScript
import { DurableObject } from "cloudflare:workers";
import { MESSAGE_TYPE } from "./shared";
import { validateUpgradeRequest } from "./validateUpgradeRequest";
import { packMessage, unpackMessage, } from "./protocol";
/**
 * Durable Object that holds a set of hibernatable WebSocket connections and
 * streams server-action results and RSC re-renders to them.
 *
 * Each accepted socket carries its `clientId` as its serialized attachment;
 * the remaining per-client state (`url`, `cookieHeaders`,
 * `shouldForwardResponses`) is persisted in Durable Object storage under
 * `client:<clientId>` and fronted by an in-memory Map cache.
 */
export class RealtimeDurableObject extends DurableObject {
  constructor(state, env) {
    super(state, env);
    this.state = state;
    this.env = env;
    this.storage = state.storage;
    // In-memory front for storage reads, keyed by clientId.
    this.clientInfoCache = new Map();
  }

  /**
   * Handles an incoming request: rejects anything that fails
   * `validateUpgradeRequest`, otherwise accepts it as a WebSocket.
   * @param {Request} request
   * @returns {Promise<Response>} the validation error response, or a 101 upgrade.
   */
  async fetch(request) {
    const validation = validateUpgradeRequest(request);
    if (!validation.valid) {
      return validation.response;
    }
    const url = new URL(request.url);
    const clientInfo = this.createClientInfo(url, request);
    return this.handleWebSocket(request, clientInfo);
  }

  /**
   * Builds the per-client record from the upgrade URL's query parameters
   * (`url`, `clientId`, `shouldForwardResponses`) and the `Cookie` header.
   * @param {URL} url
   * @param {Request} request
   */
  createClientInfo(url, request) {
    return {
      url: url.searchParams.get("url"),
      clientId: url.searchParams.get("clientId"),
      cookieHeaders: request.headers.get("Cookie") || "",
      shouldForwardResponses: url.searchParams.get("shouldForwardResponses") === "true",
    };
  }

  /** Writes client info to both the in-memory cache and storage. */
  async storeClientInfo(clientInfo) {
    this.clientInfoCache.set(clientInfo.clientId, clientInfo);
    await this.storage.put(`client:${clientInfo.clientId}`, clientInfo);
  }

  /**
   * Reads client info, preferring the cache and falling back to storage
   * (populating the cache on a storage hit).
   * @throws {Error} when no record exists for `clientId`.
   */
  async getClientInfo(clientId) {
    const cachedInfo = this.clientInfoCache.get(clientId);
    if (cachedInfo) {
      return cachedInfo;
    }
    const clientInfo = await this.storage.get(`client:${clientId}`);
    if (!clientInfo) {
      throw new Error(`Client info not found for clientId: ${clientId}`);
    }
    this.clientInfoCache.set(clientId, clientInfo);
    return clientInfo;
  }

  /**
   * Persists the client info, tags the server end of a new WebSocketPair with
   * the clientId, accepts it via the hibernation API, and returns the 101.
   */
  async handleWebSocket(request, clientInfo) {
    const { 0: client, 1: server } = new WebSocketPair();
    await this.storeClientInfo(clientInfo);
    server.serializeAttachment(clientInfo.clientId);
    this.state.acceptWebSocket(server);
    return new Response(null, { status: 101, webSocket: client });
  }

  /**
   * Hibernation handler for incoming messages. Only ACTION_REQUEST messages
   * are acted on: the client's stored `url` is updated to the request's
   * `clientUrl`, then the action is run. Failures of the action are reported
   * back to the sender as ACTION_ERROR keyed by `requestId`.
   */
  async webSocketMessage(ws, data) {
    const clientId = ws.deserializeAttachment();
    let clientInfo = await this.getClientInfo(clientId);
    const unpacked = unpackMessage(new Uint8Array(data));
    if (unpacked.type === MESSAGE_TYPE.ACTION_REQUEST) {
      const message = unpacked;
      clientInfo = {
        ...clientInfo,
        url: message.clientUrl,
      };
      await this.storeClientInfo(clientInfo);
      try {
        await this.handleAction(ws, message.id, message.args, clientInfo, message.requestId, message.clientUrl);
      } catch (error) {
        ws.send(packMessage({
          type: MESSAGE_TYPE.ACTION_ERROR,
          id: message.requestId,
          error: error instanceof Error ? error.message : String(error),
        }));
      }
    }
  }

  /**
   * Relays `response` over `ws` as a start / chunk* / end message sequence,
   * using the message types supplied in `messageTypes` and tagging every
   * message with `streamId`.
   * @param {Response} response
   * @param {WebSocket} ws
   * @param {{start: *, chunk: *, end: *}} messageTypes
   * @param {string} streamId
   */
  async streamResponse(response, ws, messageTypes, streamId) {
    ws.send(packMessage({ type: messageTypes.start, id: streamId, status: response.status }));
    const sendEnd = () => ws.send(packMessage({ type: messageTypes.end, id: streamId }));
    // A body-less response (e.g. 204) has nothing to relay, but the client
    // still needs the end marker so it does not wait forever.
    if (!response.body) {
      sendEnd();
      return;
    }
    const reader = response.body.getReader();
    try {
      for (;;) {
        const { done, value } = await reader.read();
        if (done) {
          break;
        }
        ws.send(packMessage({ type: messageTypes.chunk, id: streamId, payload: value }));
      }
      sendEnd();
    } catch (error) {
      // e.g. the socket closed mid-stream: stop pulling from upstream instead
      // of leaving the body half-read, then surface the original error.
      await reader.cancel(error).catch(() => {});
      throw error;
    } finally {
      reader.releaseLock();
    }
  }

  /**
   * Executes a server action by POSTing `args` to `clientUrl` (with `__rsc`
   * and, when given, `__rsc_action_id` set), re-renders every other client,
   * and streams the action response back to the caller.
   * @throws {Error} when the action response is not ok and the client did not
   *   ask for responses to be forwarded.
   */
  async handleAction(ws, id, args, clientInfo, requestId, clientUrl) {
    const url = new URL(clientUrl);
    url.searchParams.set("__rsc", "");
    if (id != null) {
      url.searchParams.set("__rsc_action_id", id);
    }
    const response = await fetch(url.toString(), {
      method: "POST",
      body: args,
      headers: {
        "Content-Type": "application/json",
        Cookie: clientInfo.cookieHeaders,
      },
    });
    if (!response.ok && !clientInfo.shouldForwardResponses) {
      // Release the unread body before bailing out.
      await response.body?.cancel().catch(() => {});
      throw new Error(`Action failed: ${response.statusText}`);
    }
    // Intentionally not awaited so other clients re-render concurrently with
    // streaming this action's result, but its rejection must not go unhandled.
    this.render({ exclude: [clientInfo.clientId] }).catch((error) => {
      console.error("Failed to re-render other clients:", error);
    });
    await this.streamResponse(response, ws, {
      start: MESSAGE_TYPE.ACTION_START,
      chunk: MESSAGE_TYPE.ACTION_CHUNK,
      end: MESSAGE_TYPE.ACTION_END,
    }, requestId);
  }

  /**
   * Selects accepted sockets by clientId and pairs each with its client info.
   * An empty `include` means "all sockets"; `exclude` always wins. Sockets
   * whose client info cannot be loaded are logged and skipped so one stale
   * entry does not block updates for everyone else.
   * @returns {Promise<Array<{socket: WebSocket, clientInfo: object}>>}
   */
  async determineSockets({ include = [], exclude = [] } = {}) {
    const includeSet = include.length > 0 ? new Set(include) : null;
    const excludeSet = exclude.length > 0 ? new Set(exclude) : null;
    const targets = Array.from(this.state.getWebSockets())
      .map((socket) => ({ socket, clientId: socket.deserializeAttachment() }))
      .filter(({ clientId }) => !excludeSet?.has(clientId) && (!includeSet || includeSet.has(clientId)));
    const resolved = await Promise.all(targets.map(async ({ socket, clientId }) => {
      try {
        return { socket, clientInfo: await this.getClientInfo(clientId) };
      } catch (error) {
        console.error(`Skipping socket for clientId ${clientId}:`, error);
        return null;
      }
    }));
    return resolved.filter((entry) => entry !== null);
  }

  /**
   * Fetches a fresh RSC payload (`__rsc=true`) for each selected client's
   * stored URL and streams it as RSC_START / RSC_CHUNK / RSC_END. Per-socket
   * failures are logged and do not affect other sockets.
   */
  async render({ include, exclude } = {}) {
    const sockets = await this.determineSockets({ include, exclude });
    if (sockets.length === 0) {
      return;
    }
    await Promise.all(sockets.map(async ({ socket, clientInfo }) => {
      try {
        const url = new URL(clientInfo.url);
        url.searchParams.set("__rsc", "true");
        const response = await fetch(url.toString(), {
          headers: {
            "Content-Type": "application/json",
            Cookie: clientInfo.cookieHeaders,
          },
        });
        if (!response.ok) {
          console.error(`Failed to fetch RSC update: ${response.statusText}`);
          // Release the unread body.
          await response.body?.cancel().catch(() => {});
          return;
        }
        const rscId = crypto.randomUUID();
        await this.streamResponse(response, socket, {
          start: MESSAGE_TYPE.RSC_START,
          chunk: MESSAGE_TYPE.RSC_CHUNK,
          end: MESSAGE_TYPE.RSC_END,
        }, rscId);
      } catch (err) {
        console.error("Failed to process socket:", err);
      }
    }));
  }

  /** Drops client info from both the cache and storage. */
  async removeClientInfo(clientId) {
    this.clientInfoCache.delete(clientId);
    await this.storage.delete(`client:${clientId}`);
  }

  /**
   * Hibernation handler for a closed connection: removes the client's stored
   * info and answers the peer's close frame.
   */
  async webSocketClose(ws, code, reason) {
    const clientId = ws.deserializeAttachment();
    await this.removeClientInfo(clientId);
    try {
      // Complete the close handshake. This throws for reserved codes (e.g.
      // 1005/1006) or when the socket is already closed; both are harmless.
      ws.close(code, reason);
    } catch {
      // Nothing left to do.
    }
  }

  /**
   * Hibernation handler for a socket error: logs it and removes the client's
   * stored info so errored connections do not leak storage entries.
   */
  async webSocketError(ws, error) {
    const clientId = ws.deserializeAttachment();
    console.error(`WebSocket error for clientId ${clientId}:`, error);
    await this.removeClientInfo(clientId);
  }
}