UNPKG

partyserver

Version:

Build real-time applications powered by [Durable Objects](https://developers.cloudflare.com/durable-objects/), inspired by [PartyKit](https://www.partykit.io/).

619 lines (616 loc) 19.1 kB
// src/index.ts
import { DurableObject } from "cloudflare:workers";
import { nanoid } from "nanoid";

// src/connection.ts

// When the runtime only exposes READY_STATE_* names on WebSocket, copy them
// onto the constructor and prototype under the standard names
// (CONNECTING/OPEN/CLOSING/CLOSED) so either spelling can be used.
if (!("OPEN" in WebSocket)) {
  const WebSocketStatus = {
    // @ts-expect-error
    CONNECTING: WebSocket.READY_STATE_CONNECTING,
    // @ts-expect-error
    OPEN: WebSocket.READY_STATE_OPEN,
    // @ts-expect-error
    CLOSING: WebSocket.READY_STATE_CLOSING,
    // @ts-expect-error
    CLOSED: WebSocket.READY_STATE_CLOSED
  };
  Object.assign(WebSocket, WebSocketStatus);
  Object.assign(WebSocket.prototype, WebSocketStatus);
}

/**
 * Memoizing front for the native WebSocket attachment API.
 *
 * Reads and writes go through `WebSocket.prototype.deserializeAttachment` /
 * `WebSocket.prototype.serializeAttachment` directly, so they bypass the
 * per-instance overrides that `createLazyConnection` installs. Read results
 * are cached in a WeakMap keyed by the socket.
 */
var AttachmentCache = class {
  #cache = /* @__PURE__ */ new WeakMap();

  /**
   * Return the full attachment object for `ws`.
   * @param {WebSocket} ws
   * @returns {object} the cached or freshly deserialized attachment
   * @throws {Error} when the socket carries no attachment at all
   */
  get(ws) {
    let attachment = this.#cache.get(ws);
    if (!attachment) {
      attachment = WebSocket.prototype.deserializeAttachment.call(ws);
      if (attachment !== void 0) {
        this.#cache.set(ws, attachment);
      } else {
        throw new Error(
          "Missing websocket attachment. This is most likely an issue in PartyServer, please open an issue at https://github.com/threepointone/partyserver/issues"
        );
      }
    }
    return attachment;
  }

  /**
   * Replace the attachment for `ws`, updating both the cache and the socket.
   * @param {WebSocket} ws
   * @param {object} attachment
   */
  set(ws, attachment) {
    this.#cache.set(ws, attachment);
    WebSocket.prototype.serializeAttachment.call(ws, attachment);
  }
};

var attachments = new AttachmentCache();
// Sockets that have already been wrapped by createLazyConnection.
var connections = /* @__PURE__ */ new WeakSet();
var isWrapped = (ws) => {
  return connections.has(ws);
};

/**
 * Wrap a raw (possibly hibernation-restored) WebSocket as a PartyServer
 * connection by defining connection properties directly on the socket.
 *
 * The attachment is split into `__pk` (framework data: `id`, `server`) and
 * `__user` (user state exposed via `state` / `setState`). A socket is wrapped
 * at most once; later calls return it unchanged.
 *
 * If the socket has a `state` property (as sockets built in `Server#fetch`
 * do), that value is removed and, when truthy, moved into the attachment.
 *
 * @param {WebSocket} ws
 * @returns {WebSocket} the same object with connection properties defined
 */
var createLazyConnection = (ws) => {
  if (isWrapped(ws)) {
    return ws;
  }
  let initialState = void 0;
  if ("state" in ws) {
    initialState = ws.state;
    delete ws.state;
  }
  const connection = Object.defineProperties(ws, {
    id: {
      get() {
        return attachments.get(ws).__pk.id;
      }
    },
    server: {
      get() {
        return attachments.get(ws).__pk.server;
      }
    },
    socket: {
      get() {
        return ws;
      }
    },
    state: {
      get() {
        return ws.deserializeAttachment();
      }
    },
    setState: {
      // Accepts either a new value or an updater `(prev) => next`.
      value: function setState(setState) {
        let state;
        if (setState instanceof Function) {
          state = setState(this.state);
        } else {
          state = setState;
        }
        ws.serializeAttachment(state);
        return state;
      }
    },
    // Overrides expose only the user part of the attachment, so user code
    // cannot read or clobber the framework's `__pk` data.
    deserializeAttachment: {
      value: function deserializeAttachment() {
        const attachment = attachments.get(ws);
        return attachment.__user ?? null;
      }
    },
    serializeAttachment: {
      value: function serializeAttachment(attachment) {
        const setting = {
          ...attachments.get(ws),
          __user: attachment ?? null
        };
        attachments.set(ws, setting);
      }
    }
  });
  if (initialState) {
    connection.setState(initialState);
  }
  connections.add(connection);
  return connection;
};

/**
 * Lazy iterator over hibernatable sockets.
 *
 * The socket list is fetched from `state.getWebSockets(tag)` on the first
 * `next()` call. Sockets that are not OPEN are skipped; the rest are wrapped
 * with `createLazyConnection`.
 */
var HibernatingConnectionIterator = class {
  constructor(state, tag) {
    this.state = state;
    this.tag = tag;
  }
  index = 0;
  sockets;
  [Symbol.iterator]() {
    return this;
  }
  next() {
    const sockets = (
      // biome-ignore lint/suspicious/noAssignInExpressions: <explanation>
      this.sockets ?? (this.sockets = this.state.getWebSockets(this.tag))
    );
    let socket;
    while (socket = sockets[this.index++]) {
      if (socket.readyState === WebSocket.READY_STATE_OPEN) {
        const value = createLazyConnection(socket);
        return { done: false, value };
      }
    }
    return { done: true, value: void 0 };
  }
};

/**
 * Connection manager used when hibernation is disabled: keeps accepted
 * sockets in a Map keyed by connection id and their tags in a WeakMap.
 */
var InMemoryConnectionManager = class {
  #connections = /* @__PURE__ */ new Map();
  tags = /* @__PURE__ */ new WeakMap();

  getCount() {
    return this.#connections.size;
  }

  getConnection(id) {
    return this.#connections.get(id);
  }

  /**
   * Yield OPEN connections, optionally only those carrying `tag`.
   * Both paths apply the same OPEN filter, matching the hibernating
   * iterator. A plain loop is used instead of Iterator helpers
   * (`values().filter`) so this works on runtimes that lack them.
   * @param {string} [tag]
   */
  *getConnections(tag) {
    for (const connection of this.#connections.values()) {
      if (connection.readyState !== WebSocket.READY_STATE_OPEN) {
        continue;
      }
      if (!tag) {
        yield connection;
        continue;
      }
      const connectionTags = this.tags.get(connection) ?? [];
      if (connectionTags.includes(tag)) {
        yield connection;
      }
    }
  }

  /**
   * Accept the socket, register it under its id, and record its tags (the
   * id is always the first tag). The entry is removed on close or error.
   * @param {WebSocket} connection
   * @param {{ tags: string[] }} options
   * @returns {WebSocket} the same connection
   */
  accept(connection, options) {
    connection.accept();
    this.#connections.set(connection.id, connection);
    this.tags.set(connection, [
      // make sure we have id tag
      connection.id,
      ...options.tags.filter((t) => t !== connection.id)
    ]);
    const removeConnection = () => {
      // A client that reconnects with the same `_pk` id may already have
      // replaced this socket in the map; only remove the entry if it is
      // still ours, so the newer connection is not evicted.
      if (this.#connections.get(connection.id) === connection) {
        this.#connections.delete(connection.id);
      }
      connection.removeEventListener("close", removeConnection);
      connection.removeEventListener("error", removeConnection);
    };
    connection.addEventListener("close", removeConnection);
    connection.addEventListener("error", removeConnection);
    return connection;
  }
};

/**
 * Connection manager used when hibernation is enabled: delegates storage to
 * the Durable Object state (`acceptWebSocket` / `getWebSockets`) and keeps
 * connection metadata in the socket attachment.
 */
var HibernatingConnectionManager = class {
  constructor(controller) {
    this.controller = controller;
  }

  getCount() {
    return Number(this.controller.getWebSockets().length);
  }

  /**
   * Look up a connection by id (the id is always one of its tags).
   * @returns {WebSocket | undefined}
   * @throws {Error} when more than one socket carries that tag
   */
  getConnection(id) {
    const sockets = this.controller.getWebSockets(id);
    if (sockets.length === 0) return void 0;
    if (sockets.length === 1) return createLazyConnection(sockets[0]);
    throw new Error(
      `More than one connection found for id ${id}. Did you mean to use getConnections(tag) instead?`
    );
  }

  getConnections(tag) {
    return new HibernatingConnectionIterator(this.controller, tag);
  }

  /**
   * Validate tags, hand the socket to the runtime for hibernation, and store
   * `{ __pk: { id, server }, __user: null }` as its attachment.
   * Tag rules enforced here: at most 10 tags including the id tag, each a
   * non-empty string of at most 256 characters.
   * @param {WebSocket} connection
   * @param {{ tags: string[], server: string }} options
   * @returns {WebSocket} the wrapped connection
   * @throws {Error} on any tag rule violation
   */
  accept(connection, options) {
    const tags = [
      connection.id,
      ...options.tags.filter((t) => t !== connection.id)
    ];
    if (tags.length > 10) {
      throw new Error(
        "A connection can only have 10 tags, including the default id tag."
      );
    }
    for (const tag of tags) {
      if (typeof tag !== "string") {
        throw new Error(`A connection tag must be a string.
Received: ${tag}`);
      }
      if (tag === "") {
        throw new Error("A connection tag must not be an empty string.");
      }
      if (tag.length > 256) {
        throw new Error("A connection tag must not exceed 256 characters");
      }
    }
    this.controller.acceptWebSocket(connection, tags);
    connection.serializeAttachment({
      __pk: {
        id: connection.id,
        server: options.server
      },
      __user: null
    });
    return createLazyConnection(connection);
  }
};

// src/index.ts

// Per-env cache of { kebab-case binding name -> Durable Object namespace }.
var serverMapCache = /* @__PURE__ */ new WeakMap();

/**
 * Get a stub for the server named `name` in `serverNamespace`.
 *
 * Also fires a request (not awaited) to the server's
 * `/cdn-cgi/partyserver/set-name/` route carrying the `x-partykit-room`
 * header so the server learns its name; failures are logged, not thrown.
 *
 * @param serverNamespace Durable Object namespace binding
 * @param {string} name
 * @param {{ jurisdiction?: string }} [options] also forwarded to `namespace.get`
 * @returns the Durable Object stub
 */
async function getServerByName(serverNamespace, name, options) {
  if (options?.jurisdiction) {
    serverNamespace = serverNamespace.jurisdiction(options.jurisdiction);
  }
  const id = serverNamespace.idFromName(name);
  const stub = serverNamespace.get(id, options);
  const req = new Request(
    "http://dummy-example.cloudflare.com/cdn-cgi/partyserver/set-name/"
  );
  req.headers.set("x-partykit-room", name);
  stub.fetch(req).catch((e) => {
    console.error("Could not set server name:", e);
  });
  return stub;
}

/**
 * Convert a binding name to kebab-case, e.g. "MyServer" -> "my-server",
 * "MY_SERVER" -> "my-server", "my_server" -> "my-server".
 * @param {string} str
 * @returns {string}
 */
function camelCaseToKebabCase(str) {
  // All-caps names (with at least one letter) are simply lower-cased.
  if (str === str.toUpperCase() && str !== str.toLowerCase()) {
    return str.toLowerCase().replace(/_/g, "-");
  }
  let kebabified = str.replace(
    /[A-Z]/g,
    (letter) => `-${letter.toLowerCase()}`
  );
  kebabified = kebabified.startsWith("-") ? kebabified.slice(1) : kebabified;
  return kebabified.replace(/_/g, "-").replace(/-$/, "");
}

/**
 * Route a request of the form `/{prefix}/{namespace}/{name}/...` to the
 * matching server.
 *
 * `namespace` is matched against the kebab-cased names of `env` bindings
 * that expose an `idFromName` function. The forwarded request carries the
 * `x-partykit-room`, `x-partykit-namespace` and (when set)
 * `x-partykit-jurisdiction` headers.
 *
 * `options.onBeforeConnect` (WebSocket upgrades) or `options.onBeforeRequest`
 * (everything else) may return a Request to replace the forwarded one, or a
 * Response to short-circuit routing.
 *
 * @param {Request} req
 * @param {object} env
 * @param {object} [options] `prefix` (default "parties"), `jurisdiction`,
 *   `onBeforeConnect`, `onBeforeRequest`; also forwarded to `namespace.get`
 * @returns {Promise<Response | null>} null when the URL is not routable here
 */
async function routePartykitRequest(req, env, options) {
  if (!serverMapCache.has(env)) {
    serverMapCache.set(
      env,
      Object.entries(env).reduce((acc, [k, v]) => {
        if (
          v &&
          typeof v === "object" &&
          "idFromName" in v &&
          typeof v.idFromName === "function"
        ) {
          Object.assign(acc, { [camelCaseToKebabCase(k)]: v });
          return acc;
        }
        return acc;
      }, {})
    );
  }
  const map = serverMapCache.get(env);
  const prefix = options?.prefix || "parties";
  // Drop empty segments exactly like the request path below, so "/parties"
  // and "parties/" behave the same as "parties".
  const prefixParts = prefix.split("/").filter(Boolean);
  const url = new URL(req.url);
  const parts = url.pathname.split("/").filter(Boolean);
  const prefixMatches = prefixParts.every(
    (part, index) => parts[index] === part
  );
  if (!prefixMatches || parts.length < prefixParts.length + 2) {
    return null;
  }
  const namespace = parts[prefixParts.length];
  const name = parts[prefixParts.length + 1];
  if (!name || !namespace) {
    return null;
  }
  // Own-property lookup only: segments like "constructor" or "__proto__"
  // must not resolve to members inherited from Object.prototype.
  let doNamespace = Object.hasOwn(map, namespace) ? map[namespace] : undefined;
  if (!doNamespace) {
    if (namespace === "main") {
      console.warn(
        "You appear to be migrating a PartyKit project to PartyServer."
      );
      const [firstParty] = Object.keys(map);
      if (firstParty !== undefined) {
        console.warn(`PartyServer doesn't have a "main" party by default. Try adding this to your PartySocket client: party: "${camelCaseToKebabCase(firstParty)}"`);
      }
    } else {
      console.error(`The url ${req.url} does not match any server namespace.
Did you forget to add a durable object binding to the class in your wrangler.toml?`);
    }
    // Nothing to route to: report "not handled", the same as any other
    // non-matching URL, instead of dereferencing an undefined namespace.
    return null;
  }
  if (options?.jurisdiction) {
    doNamespace = doNamespace.jurisdiction(options.jurisdiction);
  }
  const id = doNamespace.idFromName(name);
  const stub = doNamespace.get(id, options);
  req = new Request(req);
  req.headers.set("x-partykit-room", name);
  req.headers.set("x-partykit-namespace", namespace);
  if (options?.jurisdiction) {
    req.headers.set("x-partykit-jurisdiction", options.jurisdiction);
  }
  if (req.headers.get("Upgrade")?.toLowerCase() === "websocket") {
    if (options?.onBeforeConnect) {
      const reqOrRes = await options.onBeforeConnect(req, {
        party: namespace,
        name
      });
      if (reqOrRes instanceof Request) {
        req = reqOrRes;
      } else if (reqOrRes instanceof Response) {
        return reqOrRes;
      }
    }
  } else {
    if (options?.onBeforeRequest) {
      const reqOrRes = await options.onBeforeRequest(req, {
        party: namespace,
        name
      });
      if (reqOrRes instanceof Request) {
        req = reqOrRes;
      } else if (reqOrRes instanceof Response) {
        return reqOrRes;
      }
    }
  }
  return stub.fetch(req);
}

/**
 * Base Durable Object class for PartyServer servers.
 *
 * Subclasses set `static options = { hibernate: true }` to use the
 * hibernation-backed connection manager; otherwise connections are tracked
 * in memory with event listeners attached in `fetch`.
 *
 * The server's name comes from the `x-partykit-room` header on the first
 * `fetch`, or from `setName`. Setting it runs `onStart` once inside
 * `ctx.blockConcurrencyWhile`.
 */
var Server = class extends DurableObject {
  static options = {
    hibernate: false
  };
  // Lifecycle: "zero" -> "starting" (onStart running) -> "started".
  #status = "zero";
  // The concrete subclass, used for its static `options` and its name.
  #ParentClass = Object.getPrototypeOf(this).constructor;
  #connectionManager = this.#ParentClass.options.hibernate ? new HibernatingConnectionManager(this.ctx) : new InMemoryConnectionManager();
  // biome-ignore lint/complexity/noUselessConstructor: <explanation>
  constructor(ctx, env) {
    super(ctx, env);
  }

  /**
   * Handle incoming requests to the server.
   *
   * Sets the name from `x-partykit-room` if not yet named, answers the
   * set-name route, passes plain HTTP requests to `onRequest`, and turns
   * WebSocket upgrades into connections (id from the `_pk` query parameter
   * or a fresh nanoid). Errors are logged; if the request was a WebSocket
   * upgrade, a socket is returned that sends `{ error: stack }` and closes
   * with 1011. Otherwise the response is a 500 carrying the stack.
   */
  async fetch(request) {
    if (!this.#_name) {
      const room = request.headers.get("x-partykit-room");
      if (
        // !namespace ||
        !room
      ) {
        throw new Error(`Missing namespace or room headers when connecting to ${this.#ParentClass.name}. Did you try connecting directly to this Durable Object?
Try using getServerByName(namespace, id) instead.`);
      }
      await this.setName(room);
    }
    try {
      const url = new URL(request.url);
      if (url.pathname === "/cdn-cgi/partyserver/set-name/") {
        return Response.json({ ok: true });
      }
      if (request.headers.get("Upgrade")?.toLowerCase() !== "websocket") {
        return await this.onRequest(request);
      } else {
        const { 0: clientWebSocket, 1: serverWebSocket } = new WebSocketPair();
        let connectionId = url.searchParams.get("_pk");
        if (!connectionId) {
          connectionId = nanoid();
        }
        let connection = Object.assign(serverWebSocket, {
          id: connectionId,
          server: this.name,
          state: null,
          setState(setState) {
            let state;
            if (setState instanceof Function) {
              state = setState(this.state);
            } else {
              state = setState;
            }
            this.state = state;
            return this.state;
          }
        });
        const ctx = { request };
        const tags = await this.getConnectionTags(connection, ctx);
        connection = this.#connectionManager.accept(connection, {
          tags,
          server: this.name
        });
        if (!this.#ParentClass.options.hibernate) {
          this.#attachSocketEventHandlers(connection);
        }
        await this.onConnect(connection, ctx);
        return new Response(null, { status: 101, webSocket: clientWebSocket });
      }
    } catch (err) {
      console.error(
        `Error in ${this.#ParentClass.name}:${this.name} fetch:`,
        err
      );
      if (!(err instanceof Error)) throw err;
      // The Upgrade value is case-insensitive (RFC 6455); compare it the
      // same way as the routing check above.
      if (request.headers.get("Upgrade")?.toLowerCase() === "websocket") {
        const pair = new WebSocketPair();
        pair[1].accept();
        pair[1].send(JSON.stringify({ error: err.stack }));
        pair[1].close(1011, "Uncaught exception during session setup");
        return new Response(null, { status: 101, webSocket: pair[0] });
      } else {
        return new Response(err.stack, { status: 500 });
      }
    }
  }

  // Hibernation API entry points: restore name/startup state from the
  // socket attachment, then delegate to the user hooks.
  async webSocketMessage(ws, message) {
    const connection = createLazyConnection(ws);
    await this.setName(connection.server);
    if (this.#status !== "started") {
      await this.#initialize();
    }
    return this.onMessage(connection, message);
  }

  async webSocketClose(ws, code, reason, wasClean) {
    const connection = createLazyConnection(ws);
    await this.setName(connection.server);
    if (this.#status !== "started") {
      await this.#initialize();
    }
    return this.onClose(connection, code, reason, wasClean);
  }

  async webSocketError(ws, error) {
    const connection = createLazyConnection(ws);
    await this.setName(connection.server);
    if (this.#status !== "started") {
      await this.#initialize();
    }
    return this.onError(connection, error);
  }

  /** Run `onStart` while blocking other events, tracking `#status`. */
  async #initialize() {
    await this.ctx.blockConcurrencyWhile(async () => {
      this.#status = "starting";
      await this.onStart();
      this.#status = "started";
    });
  }

  /**
   * Non-hibernating mode: forward socket events to the user hooks. Promise
   * rejections from the hooks are logged rather than left unhandled.
   */
  #attachSocketEventHandlers(connection) {
    const handleMessageFromClient = (event) => {
      this.onMessage(connection, event.data)?.catch((e) => {
        console.error("onMessage error:", e);
      });
    };
    const handleCloseFromClient = (event) => {
      connection.removeEventListener("message", handleMessageFromClient);
      connection.removeEventListener("close", handleCloseFromClient);
      this.onClose(connection, event.code, event.reason, event.wasClean)?.catch(
        (e) => {
          console.error("onClose error:", e);
        }
      );
    };
    const handleErrorFromClient = (e) => {
      connection.removeEventListener("message", handleMessageFromClient);
      connection.removeEventListener("error", handleErrorFromClient);
      this.onError(connection, e.error)?.catch((e2) => {
        console.error("onError error:", e2);
      });
    };
    connection.addEventListener("close", handleCloseFromClient);
    connection.addEventListener("error", handleErrorFromClient);
    connection.addEventListener("message", handleMessageFromClient);
  }

  // Public API
  #_name;
  #_longErrorAboutNameThrown = false;

  /**
   * The name for this server. Write-once-only.
   * Reading it before it is set throws; the first such read throws a long
   * explanatory message, later ones a short one.
   */
  get name() {
    if (!this.#_name) {
      if (!this.#_longErrorAboutNameThrown) {
        this.#_longErrorAboutNameThrown = true;
        throw new Error(
          `Attempting to read .name on ${this.#ParentClass.name} before it was set. The name can be set by explicitly calling .setName(name) on the stub, or by using routePartykitRequest(). This is a known issue and will be fixed soon.
Follow https://github.com/cloudflare/workerd/issues/2240 for more updates.`
        );
      } else {
        throw new Error(
          `Attempting to read .name on ${this.#ParentClass.name} before it was set.`
        );
      }
    }
    return this.#_name;
  }

  /**
   * Set the server name and run `onStart` if the server has not started.
   * Async because it is also invoked remotely through the stub.
   * @param {string} name
   * @throws {Error} when `name` is empty, or differs from an existing name
   */
  async setName(name) {
    if (!name) {
      throw new Error("A name is required.");
    }
    if (this.#_name && this.#_name !== name) {
      throw new Error("This server already has a name.");
    }
    this.#_name = name;
    if (this.#status !== "started") {
      await this.ctx.blockConcurrencyWhile(async () => {
        await this.#initialize();
      });
    }
  }

  // Send to one connection; if send throws, close it with 1011.
  #sendMessageToConnection(connection, message) {
    try {
      connection.send(message);
    } catch (_e) {
      connection.close(1011, "Unexpected error");
    }
  }

  /** Send a message to all connected clients, except connection ids listed in `without` */
  broadcast(msg, without) {
    for (const connection of this.#connectionManager.getConnections()) {
      if (!without || !without.includes(connection.id)) {
        this.#sendMessageToConnection(connection, msg);
      }
    }
  }

  /** Get a connection by connection id */
  getConnection(id) {
    return this.#connectionManager.getConnection(id);
  }

  /**
   * Get all connections. Optionally, you can provide a tag to filter returned connections.
   * Use `Server#getConnectionTags` to tag the connection on connect.
   */
  getConnections(tag) {
    return this.#connectionManager.getConnections(tag);
  }

  /**
   * You can tag a connection to filter them in Server#getConnections.
   * Each connection supports up to 9 tags, each tag max length is 256 characters.
   */
  getConnectionTags(connection, context) {
    return [];
  }

  // Implemented by the user

  /**
   * Called when the server is started for the first time.
   */
  onStart() {
  }

  /**
   * Called when a new connection is made to the server.
   */
  onConnect(connection, ctx) {
    console.log(
      `Connection ${connection.id} connected to ${this.#ParentClass.name}:${this.name}`
    );
  }

  /**
   * Called when a message is received from a connection.
   */
  onMessage(connection, message) {
    console.log(
      `Received message on connection ${this.#ParentClass.name}:${connection.id}`
    );
    console.info(
      `Implement onMessage on ${this.#ParentClass.name} to handle this message.`
    );
  }

  /**
   * Called when a connection is closed.
   */
  onClose(connection, code, reason, wasClean) {
  }

  /**
   * Called when an error occurs on a connection.
   */
  onError(connection, error) {
    console.error(
      `Error on connection ${connection.id} in ${this.#ParentClass.name}:${this.name}:`,
      error
    );
    console.info(
      `Implement onError on ${this.#ParentClass.name} to handle this error.`
    );
  }

  /**
   * Called when a request is made to the server.
   */
  onRequest(request) {
    console.warn(
      `onRequest hasn't been implemented on ${this.#ParentClass.name}:${this.name} responding to ${request.url}`
    );
    return new Response("Not implemented", { status: 404 });
  }

  /** Called by `alarm()` after the server has started. */
  onAlarm() {
    console.log(
      `Implement onAlarm on ${this.#ParentClass.name} to handle alarms.`
    );
  }

  async alarm() {
    if (this.#status !== "started") {
      await this.#initialize();
    }
    await this.onAlarm();
  }
};
export {
  Server,
  getServerByName,
  routePartykitRequest
};
//# sourceMappingURL=index.js.map