UNPKG

@rivetkit/cloudflare-workers

Version:

_Lightweight Libraries for Backends_

612 lines (602 loc) 17.9 kB
// Bundled module for @rivetkit/cloudflare-workers.
// Sections are annotated with the source file they originate from.
//
// Imports are merged per module specifier; the first-occurrence order of the
// specifiers is unchanged, so module evaluation order is preserved.
import { DurableObject, env } from "cloudflare:workers";
import { Hono } from "hono";
import {
  createActorRouter,
  createClientWithDriver,
  createGenericConnDrivers,
  createInlineClientDriver,
  GenericConnGlobalState,
  lookupInRegistry
} from "@rivetkit/core";
import {
  HEADER_AUTH_DATA,
  HEADER_CONN_PARAMS,
  HEADER_ENCODING,
  HEADER_EXPOSE_INTERNAL_ERROR,
  RunConfigSchema,
  serializeEmptyPersistData
} from "@rivetkit/core/driver-helpers";
import invariant from "invariant";
import { getLogger } from "@rivetkit/core/log";
import { z } from "zod";
import { ActorAlreadyExists, InternalError } from "@rivetkit/core/errors";
import { defineWebSocketHelper, WSContext } from "hono/ws";

// ---------------------------------------------------------------------------
// src/actor-driver.ts
// ---------------------------------------------------------------------------

/**
 * Map from actor ID to the Durable Object state (`{ ctx, env }`) registered by
 * the Durable Object hosting that actor. One instance is created per
 * `createActorDurableObject` call and shared by every DO instance it produces.
 */
var CloudflareDurableObjectGlobalState = class {
  // Single map for all actor state
  #dos = /* @__PURE__ */ new Map();

  /**
   * Returns the state registered for `actorId`.
   * Fails via `invariant` when nothing has been registered for that ID.
   */
  getDOState(actorId) {
    const state = this.#dos.get(actorId);
    invariant(state !== void 0, "durable object state not in global state");
    return state;
  }

  /** Registers (or replaces) the state for `actorId`. */
  setDOState(actorId, state) {
    this.#dos.set(actorId, state);
  }
};

/** Per-actor bookkeeping kept by `CloudflareActorsActorDriver`. */
var ActorHandler = class {
  // The instantiated actor; set as soon as `instantiate()` returns.
  actor;
  // Pending while the actor is being loaded; cleared to `undefined` once the
  // load has finished (successfully or not).
  actorPromise = Promise.withResolvers();
  genericConnGlobalState = new GenericConnGlobalState();
};

/**
 * Actor driver backed by Durable Object storage. Actor state, alarms and the
 * SQL database all come from the `ctx` registered in the global state.
 */
var CloudflareActorsActorDriver = class {
  #registryConfig;
  #runConfig;
  #managerDriver;
  #inlineClient;
  #globalState;
  #actors = /* @__PURE__ */ new Map();

  constructor(registryConfig, runConfig, managerDriver, inlineClient, globalState) {
    this.#registryConfig = registryConfig;
    this.#runConfig = runConfig;
    this.#managerDriver = managerDriver;
    this.#inlineClient = inlineClient;
    this.#globalState = globalState;
  }

  /** Durable Object context registered for `actorId`. */
  #getDOCtx(actorId) {
    return this.#globalState.getDOState(actorId).ctx;
  }

  /**
   * Loads the actor with `actorId`. On first use it instantiates the actor
   * from the registry and starts it. Concurrent callers share one in-flight
   * load.
   *
   * If loading fails, every waiting caller receives the same error. The
   * half-built entry is discarded so a later call retries instead of waiting
   * on a promise that would never settle.
   *
   * @throws Error when the stored name or key is missing, or whatever the
   *   registry lookup / `actor.start` throws.
   */
  async loadActor(actorId) {
    const existing = this.#actors.get(actorId);
    if (existing) {
      if (existing.actorPromise) await existing.actorPromise.promise;
      if (!existing.actor) throw new Error("Actor should be loaded");
      return existing.actor;
    }

    const handler = new ActorHandler();
    // If loading fails while nobody is waiting, rejecting this promise must
    // not surface as an unhandled rejection.
    handler.actorPromise.promise.catch(() => {});
    this.#actors.set(actorId, handler);

    try {
      const storage = this.#getDOCtx(actorId).storage;
      const [name, key] = await Promise.all([
        storage.get(KEYS.NAME),
        storage.get(KEYS.KEY)
      ]);
      if (!name) {
        throw new Error(`Actor ${actorId} is not initialized - missing name`);
      }
      if (!key) {
        throw new Error(`Actor ${actorId} is not initialized - missing key`);
      }

      const definition = lookupInRegistry(this.#registryConfig, name);
      handler.actor = definition.instantiate();
      const connDrivers = createGenericConnDrivers(
        handler.genericConnGlobalState
      );
      await handler.actor.start(
        connDrivers,
        this,
        this.#inlineClient,
        actorId,
        name,
        key,
        "unknown" // TODO: Support regions in Cloudflare
      );
    } catch (error) {
      // Forget the failed entry so the next call retries, then wake waiters.
      if (this.#actors.get(actorId) === handler) {
        this.#actors.delete(actorId);
      }
      handler.actorPromise?.reject(error);
      handler.actorPromise = void 0;
      throw error;
    }

    handler.actorPromise?.resolve();
    handler.actorPromise = void 0;
    return handler.actor;
  }

  /** @throws Error if the actor has not been loaded by this driver. */
  getGenericConnGlobalState(actorId) {
    const handler = this.#actors.get(actorId);
    if (!handler) {
      throw new Error(`Actor ${actorId} not loaded`);
    }
    return handler.genericConnGlobalState;
  }

  /** Exposes the Durable Object context to actor code as `{ state }`. */
  getContext(actorId) {
    const state = this.#globalState.getDOState(actorId);
    return { state: state.ctx };
  }

  async readPersistedData(actorId) {
    return await this.#getDOCtx(actorId).storage.get(KEYS.PERSIST_DATA);
  }

  async writePersistedData(actorId, data) {
    await this.#getDOCtx(actorId).storage.put(KEYS.PERSIST_DATA, data);
  }

  /** Schedules the Durable Object alarm; takes the actor, not its ID. */
  async setAlarm(actor, timestamp) {
    await this.#getDOCtx(actor.id).storage.setAlarm(timestamp);
  }

  async getDatabase(actorId) {
    return this.#getDOCtx(actorId).storage.sql;
  }
};

/**
 * Returns a factory matching the `runConfig.driver.actor` signature.
 * Each call of the factory creates a new `CloudflareActorsActorDriver` bound
 * to `globalState`.
 */
function createCloudflareActorsActorDriverBuilder(globalState) {
  return (registryConfig, runConfig, managerDriver, inlineClient) => {
    return new CloudflareActorsActorDriver(
      registryConfig,
      runConfig,
      managerDriver,
      inlineClient,
      globalState
    );
  };
}

// ---------------------------------------------------------------------------
// src/log.ts
// ---------------------------------------------------------------------------

var LOGGER_NAME = "driver-cloudflare-workers";

/** Logger shared by this driver package. */
function logger() {
  return getLogger(LOGGER_NAME);
}

// ---------------------------------------------------------------------------
// src/actor-handler-do.ts
// ---------------------------------------------------------------------------

/** Durable Object storage keys used for actor bookkeeping. */
var KEYS = {
  NAME: "rivetkit:name",
  KEY: "rivetkit:key",
  PERSIST_DATA: "rivetkit:data"
};

/**
 * Builds the Durable Object class that hosts actors from `registry`.
 * As a side effect, `#loadActor` installs the actor-driver factory on
 * `runConfig.driver.actor`.
 */
function createActorDurableObject(registry, runConfig) {
  const globalState = new CloudflareDurableObjectGlobalState();

  return class ActorHandler extends DurableObject {
    // `{ name, key }` once the actor is known to be initialized.
    #initialized;
    // Shared by concurrent `#loadActor` calls made before `#initialized` is
    // known. It is resolved once initialization state is established and
    // rejected (then cleared) if reading that state fails.
    #initializedPromise;
    // `{ actorRouter, actorDriver }` once built.
    #actor;

    /**
     * Ensures the actor router and driver exist, building them on first use.
     *
     * @throws Error("Not initialized") if storage holds no actor yet. In that
     *   case concurrent callers keep waiting until `initialize()` runs.
     */
    async #loadActor() {
      if (!this.#initialized) {
        if (this.#initializedPromise) {
          await this.#initializedPromise.promise;
        } else {
          const pending = Promise.withResolvers();
          // Avoid unhandled-rejection reports when nobody else is waiting.
          pending.promise.catch(() => {});
          this.#initializedPromise = pending;
          try {
            const res = await this.ctx.storage.get([
              KEYS.NAME,
              KEYS.KEY,
              KEYS.PERSIST_DATA
            ]);
            if (res.get(KEYS.PERSIST_DATA)) {
              const name = res.get(KEYS.NAME);
              if (!name) throw new Error("missing actor name");
              const key = res.get(KEYS.KEY);
              if (!key) throw new Error("missing actor key");
              logger().debug("already initialized", { name, key });
              this.#initialized = { name, key };
              pending.resolve();
            } else {
              // Left pending on purpose: `initialize()` resolves it.
              logger().debug("waiting to initialize");
            }
          } catch (error) {
            // Wake waiters with the error and let a later call retry the read.
            if (this.#initializedPromise === pending) {
              this.#initializedPromise = void 0;
            }
            pending.reject(error);
            throw error;
          }
        }
      }

      if (this.#actor) {
        return this.#actor;
      }
      if (!this.#initialized) throw new Error("Not initialized");

      const actorId = this.ctx.id.toString();
      globalState.setDOState(actorId, { ctx: this.ctx, env });

      runConfig.driver.actor =
        createCloudflareActorsActorDriverBuilder(globalState);
      const managerDriver = runConfig.driver.manager(
        registry.config,
        runConfig
      );
      const inlineClient = createClientWithDriver(
        createInlineClientDriver(managerDriver)
      );
      const actorDriver = runConfig.driver.actor(
        registry.config,
        runConfig,
        managerDriver,
        inlineClient
      );
      const actorRouter = createActorRouter(runConfig, actorDriver);
      // Keep the driver so `alarm()` reaches the same actor instance.
      this.#actor = { actorRouter, actorDriver };

      await actorDriver.loadActor(actorId);
      return this.#actor;
    }

    /** RPC called by the service that creates the DO to initialize it. */
    async initialize(req) {
      await this.ctx.storage.put({
        [KEYS.NAME]: req.name,
        [KEYS.KEY]: req.key,
        [KEYS.PERSIST_DATA]: serializeEmptyPersistData(req.input)
      });
      this.#initialized = { name: req.name, key: req.key };
      // Release requests that arrived before initialization and are waiting
      // in `#loadActor`.
      this.#initializedPromise?.resolve();
      logger().debug("initialized actor", { key: req.key });
      await this.#loadActor();
    }

    /** Routes an incoming request to this actor's router. */
    async fetch(request) {
      const { actorRouter } = await this.#loadActor();
      const actorId = this.ctx.id.toString();
      return await actorRouter.fetch(request, { actorId });
    }

    /**
     * Durable Object alarm handler; forwards to the actor's `onAlarm`.
     *
     * This uses the driver built in `#loadActor`. Building a fresh driver here
     * would instantiate and start a second copy of the actor, because each
     * driver keeps its own table of loaded actors.
     */
    async alarm() {
      const { actorDriver } = await this.#loadActor();
      const actorId = this.ctx.id.toString();
      const actor = await actorDriver.loadActor(actorId);
      await actor.onAlarm();
    }
  };
}

// ---------------------------------------------------------------------------
// src/config.ts
// ---------------------------------------------------------------------------

/**
 * User-facing config: the core run config without the fields this package
 * supplies itself (`driver`, `getUpgradeWebSocket`), plus an optional `app`.
 */
var ConfigSchema = RunConfigSchema.removeDefault()
  .omit({ driver: true, getUpgradeWebSocket: true })
  .extend({ app: z.custom().optional() })
  .default({});

// ---------------------------------------------------------------------------
// src/util.ts
// ---------------------------------------------------------------------------

var EMPTY_KEY = "(none)";
var KEY_SEPARATOR = ",";

/**
 * Serializes `name` and `key` into the string used to derive the actor's
 * Durable Object ID (via `idFromName`).
 *
 * NOTE(review): `:` in the name is escaped but `\` is not. A name ending in a
 * backslash can therefore collide with a name containing `:`; for example
 * ("a\\", ["x:y"]) and ("a:x", ["y"]) both serialize to `a\:x:y`. The
 * escaping is left unchanged here because changing it would change the DO
 * IDs of existing actors.
 */
function serializeNameAndKey(name, key) {
  const escapedName = name.replace(/:/g, "\\:");
  // `serializeKey` already maps the empty key to EMPTY_KEY.
  return `${escapedName}:${serializeKey(key)}`;
}

/**
 * Joins key parts with `,`, escaping backslashes and commas. A part equal to
 * the empty-key sentinel is prefixed with `\` so it cannot be mistaken for an
 * empty key.
 */
function serializeKey(key) {
  if (key.length === 0) {
    return EMPTY_KEY;
  }
  const escapedParts = key.map((part) => {
    if (part === EMPTY_KEY) {
      return `\\${EMPTY_KEY}`;
    }
    let escaped = part.replace(/\\/g, "\\\\");
    escaped = escaped.replace(/,/g, "\\,");
    return escaped;
  });
  return escapedParts.join(KEY_SEPARATOR);
}

// ---------------------------------------------------------------------------
// src/manager-driver.ts
// ---------------------------------------------------------------------------

var KEYS2 = {
  ACTOR: {
    // Combined key for actor metadata (name and key)
    metadata: (actorId) => `actor:${actorId}:metadata`,
    // Key index for actor lookup. The name is included so actors of different
    // types that share a key do not overwrite each other's entry.
    keyIndex: (name, key = []) => {
      return `actor_key:${serializeNameAndKey(name, key)}`;
    }
  }
};

// Only these request headers are forwarded to the DO on WebSocket proxying.
var STANDARD_WEBSOCKET_HEADERS = [
  "connection",
  "upgrade",
  "sec-websocket-key",
  "sec-websocket-version",
  "sec-websocket-protocol",
  "sec-websocket-extensions"
];

/** RFC 6455: the `Upgrade` header value is ASCII case-insensitive. */
function isWebSocketUpgrade(upgradeHeader) {
  return (
    typeof upgradeHeader === "string" &&
    upgradeHeader.toLowerCase() === "websocket"
  );
}

/**
 * Manager driver that addresses actors through the `ACTOR_DO` Durable Object
 * namespace and keeps metadata in the `ACTOR_KV` namespace.
 */
var CloudflareActorsManagerDriver = class {
  /** Sends `actorRequest` to the actor's Durable Object. */
  async sendRequest(actorId, actorRequest) {
    const cfEnv = getCloudflareAmbientEnv();
    logger().debug("sending request to durable object", {
      actorId,
      method: actorRequest.method,
      url: actorRequest.url
    });
    const id = cfEnv.ACTOR_DO.idFromString(actorId);
    const stub = cfEnv.ACTOR_DO.get(id);
    return await stub.fetch(actorRequest);
  }

  /**
   * Opens a WebSocket to the actor's Durable Object and accepts it locally.
   *
   * @throws InternalError if the DO response carries no WebSocket.
   */
  async openWebSocket(path, actorId, encoding, params) {
    const cfEnv = getCloudflareAmbientEnv();
    logger().debug("opening websocket to durable object", { actorId, path });
    const id = cfEnv.ACTOR_DO.idFromString(actorId);
    const stub = cfEnv.ACTOR_DO.get(id);

    const headers = {
      Upgrade: "websocket",
      Connection: "Upgrade",
      [HEADER_EXPOSE_INTERNAL_ERROR]: "true",
      [HEADER_ENCODING]: encoding
    };
    if (params) {
      headers[HEADER_CONN_PARAMS] = JSON.stringify(params);
    }
    headers["sec-websocket-protocol"] = "rivetkit";

    const url = `http://actor${path}`;
    logger().debug("rewriting websocket url", { from: path, to: url });
    const response = await stub.fetch(url, { headers });
    const webSocket = response.webSocket;
    if (!webSocket) {
      throw new InternalError(
        "missing websocket connection in response from DO"
      );
    }
    logger().debug("durable object websocket connection open", { actorId });

    webSocket.accept();
    // Emit a synthetic "open" event on the next tick, presumably so handlers
    // the caller attaches after this method returns still observe it.
    setTimeout(() => {
      const event = new Event("open");
      webSocket.onopen?.call(webSocket, event);
      webSocket.dispatchEvent(event);
    }, 0);
    return webSocket;
  }

  /** Forwards an incoming HTTP request to the actor's Durable Object. */
  async proxyRequest(c, actorRequest, actorId) {
    logger().debug("forwarding request to durable object", {
      actorId,
      method: actorRequest.method,
      url: actorRequest.url
    });
    const id = c.env.ACTOR_DO.idFromString(actorId);
    const stub = c.env.ACTOR_DO.get(id);
    return await stub.fetch(actorRequest);
  }

  /**
   * Forwards an incoming WebSocket upgrade to the actor's Durable Object.
   * All headers except the standard WebSocket ones are stripped, then the
   * RivetKit headers are added. Responds 426 when the request is not an upgrade.
   */
  async proxyWebSocket(c, path, actorId, encoding, params, authData) {
    logger().debug("forwarding websocket to durable object", {
      actorId,
      path
    });
    if (!isWebSocketUpgrade(c.req.header("Upgrade"))) {
      return new Response("Expected Upgrade: websocket", { status: 426 });
    }

    const newUrl = new URL(`http://actor${path}`);
    const actorRequest = new Request(newUrl, c.req.raw);
    logger().debug("rewriting websocket url", {
      from: c.req.url,
      to: actorRequest.url
    });

    // Collect the keys first: deleting while iterating `forEach` is unsafe.
    const headerKeys = [];
    actorRequest.headers.forEach((_value, name) => headerKeys.push(name));
    for (const name of headerKeys) {
      if (!STANDARD_WEBSOCKET_HEADERS.includes(name)) {
        actorRequest.headers.delete(name);
      }
    }

    actorRequest.headers.set(HEADER_EXPOSE_INTERNAL_ERROR, "true");
    actorRequest.headers.set(HEADER_ENCODING, encoding);
    if (params) {
      actorRequest.headers.set(HEADER_CONN_PARAMS, JSON.stringify(params));
    }
    if (authData) {
      actorRequest.headers.set(HEADER_AUTH_DATA, JSON.stringify(authData));
    }

    const id = c.env.ACTOR_DO.idFromString(actorId);
    const stub = c.env.ACTOR_DO.get(id);
    return await stub.fetch(actorRequest);
  }

  /** Looks up actor metadata by ID; `undefined` when unknown. */
  async getForId({ c, actorId }) {
    return await this.#buildActorOutput(c, actorId);
  }

  /**
   * Looks up an actor by name and key. The DO ID is derived
   * deterministically from the serialized name and key, then checked against
   * KV metadata. Returns `undefined` when no metadata exists.
   */
  async getWithKey({ c, name, key }) {
    const cfEnv = getCloudflareAmbientEnv();
    logger().debug("getWithKey: searching for actor", { name, key });

    const nameKeyString = serializeNameAndKey(name, key);
    const actorId = cfEnv.ACTOR_DO.idFromName(nameKeyString).toString();
    const actorData = await cfEnv.ACTOR_KV.get(KEYS2.ACTOR.metadata(actorId), {
      type: "json"
    });
    if (!actorData) {
      logger().debug("getWithKey: no actor found with matching name and key", {
        name,
        key,
        actorId
      });
      return void 0;
    }

    logger().debug("getWithKey: found actor with matching name and key", {
      actorId,
      name,
      key
    });
    // Build the result from the entry just read; a second KV read is unneeded.
    return { actorId, name: actorData.name, key: actorData.key };
  }

  async getOrCreateWithKey(input) {
    const getOutput = await this.getWithKey(input);
    if (getOutput) {
      return getOutput;
    }
    return await this.createActor(input);
  }

  /**
   * Creates and initializes the actor's DO, then records its metadata and key
   * index in KV.
   *
   * NOTE(review): the existence check and `initialize()` are not atomic, and
   * KV reads are eventually consistent. A concurrent or recently-created
   * actor may therefore be re-initialized, which overwrites its persisted
   * data. Consider making `initialize()` refuse to overwrite existing state.
   *
   * @throws ActorAlreadyExists if KV already has metadata for this name/key.
   */
  async createActor({ c, name, key, input }) {
    const cfEnv = getCloudflareAmbientEnv();
    const existingActor = await this.getWithKey({ c, name, key });
    if (existingActor) {
      throw new ActorAlreadyExists(name, key);
    }

    const nameKeyString = serializeNameAndKey(name, key);
    const doId = cfEnv.ACTOR_DO.idFromName(nameKeyString);
    const actorId = doId.toString();
    const actor = cfEnv.ACTOR_DO.get(doId);
    await actor.initialize({ name, key, input });

    const actorData = { name, key };
    await cfEnv.ACTOR_KV.put(
      KEYS2.ACTOR.metadata(actorId),
      JSON.stringify(actorData)
    );
    await cfEnv.ACTOR_KV.put(KEYS2.ACTOR.keyIndex(name, key), actorId);

    return { actorId, name, key };
  }

  // Helper method to build actor output from an ID
  async #buildActorOutput(c, actorId) {
    const cfEnv = getCloudflareAmbientEnv();
    const actorData = await cfEnv.ACTOR_KV.get(KEYS2.ACTOR.metadata(actorId), {
      type: "json"
    });
    if (!actorData) {
      return void 0;
    }
    return { actorId, name: actorData.name, key: actorData.key };
  }
};

// ---------------------------------------------------------------------------
// src/websocket.ts
// ---------------------------------------------------------------------------

/**
 * Hono WebSocket helper for Workers. It creates a `WebSocketPair`, wires the
 * server side to `events`, and returns a 101 response carrying the client side.
 * Non-upgrade requests produce no response (`undefined`).
 */
var upgradeWebSocket = defineWebSocketHelper(async (c, events) => {
  if (!isWebSocketUpgrade(c.req.header("Upgrade"))) {
    return;
  }

  const webSocketPair = new WebSocketPair();
  const client = webSocketPair[0];
  const server = webSocketPair[1];

  const wsContext = new WSContext({
    close: (code, reason) => server.close(code, reason),
    get protocol() {
      return server.protocol;
    },
    raw: server,
    get readyState() {
      return server.readyState;
    },
    url: server.url ? new URL(server.url) : null,
    send: (source) => server.send(source)
  });

  if (events.onClose) {
    server.addEventListener("close", (evt) => events.onClose?.(evt, wsContext));
  }
  if (events.onMessage) {
    server.addEventListener("message", (evt) =>
      events.onMessage?.(evt, wsContext)
    );
  }
  if (events.onError) {
    server.addEventListener("error", (evt) => events.onError?.(evt, wsContext));
  }

  server.accept?.();
  events.onOpen?.(new Event("open"), wsContext);

  return new Response(null, {
    status: 101,
    headers: {
      // HACK: Required in order for Cloudflare to not error with "Network connection lost"
      //
      // This bug is undocumented and cannot easily be reproduced outside of RivetKit.
      "Sec-WebSocket-Protocol": "rivetkit"
    },
    webSocket: client
  });
});

// ---------------------------------------------------------------------------
// src/handler.ts
// ---------------------------------------------------------------------------

/** The Workers environment bindings, as exposed by `cloudflare:workers`. */
function getCloudflareAmbientEnv() {
  return env;
}

/** Shorthand for `createServer(registry, inputConfig).createHandler()`. */
function createServerHandler(registry, inputConfig) {
  const { createHandler } = createServer(registry, inputConfig);
  return createHandler();
}

/**
 * Wires `registry` to the Cloudflare drivers.
 *
 * @returns `{ client, createHandler }`. `createHandler(hono?)` returns the
 *   Worker `handler` and the `ActorHandler` Durable Object class. When no Hono
 *   app is passed, a new one is created with the registry mounted at
 *   `/registry`.
 */
function createServer(registry, inputConfig) {
  const config = ConfigSchema.parse(inputConfig);

  const runConfig = {
    driver: {
      name: "cloudflare-workers",
      manager: () => new CloudflareActorsManagerDriver(),
      // HACK: We can't build the actor driver until we're inside the Durable Object
      actor: void 0
    },
    getUpgradeWebSocket: () => upgradeWebSocket,
    ...config
  };

  const ActorHandler2 = createActorDurableObject(registry, runConfig);
  const serverOutput = registry.createServer(runConfig);

  return {
    client: serverOutput.client,
    createHandler: (hono) => {
      const app = hono ?? new Hono();
      if (!hono) {
        app.route("/registry", serverOutput.hono);
      }
      const handler = {
        fetch: (request, workerEnv, ctx) => {
          return app.fetch(request, workerEnv, ctx);
        }
      };
      return { handler, ActorHandler: ActorHandler2 };
    }
  };
}

export { createServer, createServerHandler };
//# sourceMappingURL=mod.js.map