@rivetkit/cloudflare-workers
Version:
_Lightweight Libraries for Backends_
612 lines (602 loc) • 17.9 kB
JavaScript
// src/handler.ts
import { env as env2 } from "cloudflare:workers";
import { Hono } from "hono";
// src/actor-handler-do.ts
import { DurableObject, env } from "cloudflare:workers";
import {
createActorRouter,
createClientWithDriver,
createInlineClientDriver
} from "@rivetkit/core";
import { serializeEmptyPersistData } from "@rivetkit/core/driver-helpers";
// src/actor-driver.ts
import {
createGenericConnDrivers,
GenericConnGlobalState,
lookupInRegistry
} from "@rivetkit/core";
import invariant from "invariant";
/**
 * Registry of per-Durable-Object state, keyed by actor ID.
 *
 * One instance is shared between the Durable Object class and the actor
 * drivers it creates, so a driver can look up the state for an actor ID.
 */
var CloudflareDurableObjectGlobalState = class {
  // Actor ID -> state object stored through setDOState.
  #dos = /* @__PURE__ */ new Map();

  /**
   * Returns the state stored for `actorId`.
   * Fails the invariant if nothing has been stored for that ID.
   */
  getDOState(actorId) {
    const entry = this.#dos.get(actorId);
    invariant(entry !== void 0, "durable object state not in global state");
    return entry;
  }

  /** Stores (or replaces) the state associated with `actorId`. */
  setDOState(actorId, state) {
    this.#dos.set(actorId, state);
  }
};
// Per-actor bookkeeping kept by CloudflareActorsActorDriver in its actor map.
var ActorHandler = class {
// Actor instance; assigned by loadActor once the definition has been instantiated.
actor;
// Resolvers that loadActor settles after the actor has started; concurrent
// loadActor calls for the same actor await `actorPromise.promise`.
actorPromise = Promise.withResolvers();
// Connection state handed to createGenericConnDrivers for this actor.
genericConnGlobalState = new GenericConnGlobalState();
};
/**
 * Actor driver that runs actors inside a Cloudflare Durable Object.
 *
 * Persistence, alarms and the SQL database are all served from the Durable
 * Object context (`ctx`) registered in the shared global state for the actor.
 */
var CloudflareActorsActorDriver = class {
  #registryConfig;
  #runConfig;
  #managerDriver;
  #inlineClient;
  #globalState;
  // Actor ID -> ActorHandler for actors loaded (or loading) through this driver.
  #actors = /* @__PURE__ */ new Map();

  constructor(registryConfig, runConfig, managerDriver, inlineClient, globalState) {
    this.#registryConfig = registryConfig;
    this.#runConfig = runConfig;
    this.#managerDriver = managerDriver;
    this.#inlineClient = inlineClient;
    this.#globalState = globalState;
  }

  /** Returns the Durable Object context registered for `actorId`. */
  #getDOCtx(actorId) {
    return this.#globalState.getDOState(actorId).ctx;
  }

  /**
   * Loads and starts the actor for `actorId`, or returns the instance that is
   * already loaded. Concurrent calls for the same ID share a single load.
   *
   * @param actorId Actor ID (the Durable Object ID string).
   * @returns The started actor instance.
   * @throws If the actor's name or key is missing from storage, or if
   *   instantiating/starting the actor fails. A failed load is evicted so a
   *   later call can retry instead of waiting forever on the failed attempt.
   */
  async loadActor(actorId) {
    const existing = this.#actors.get(actorId);
    if (existing) {
      // Another call is (or was) loading this actor; wait for it to finish.
      if (existing.actorPromise) await existing.actorPromise.promise;
      if (!existing.actor) throw new Error("Actor should be loaded");
      return existing.actor;
    }

    // Register the handler before the first await so concurrent callers
    // find it and wait instead of starting a second load.
    const handler = new ActorHandler();
    this.#actors.set(actorId, handler);

    try {
      const doState = this.#globalState.getDOState(actorId);
      const storage = doState.ctx.storage;
      const [name, key] = await Promise.all([
        storage.get(KEYS.NAME),
        storage.get(KEYS.KEY)
      ]);
      if (!name) {
        throw new Error(`Actor ${actorId} is not initialized - missing name`);
      }
      if (!key) {
        throw new Error(`Actor ${actorId} is not initialized - missing key`);
      }

      const definition = lookupInRegistry(this.#registryConfig, name);
      handler.actor = definition.instantiate();
      const connDrivers = createGenericConnDrivers(
        handler.genericConnGlobalState
      );
      await handler.actor.start(
        connDrivers,
        this,
        this.#inlineClient,
        actorId,
        name,
        key,
        "unknown"
        // TODO: Support regions in Cloudflare
      );
    } catch (error) {
      // Evict the half-loaded handler so the next call retries from scratch.
      if (this.#actors.get(actorId) === handler) {
        this.#actors.delete(actorId);
      }
      const pending = handler.actorPromise;
      if (pending) {
        // Waiters (if any) receive the error; the no-op catch keeps the
        // rejection from being reported as unhandled when nobody is waiting.
        pending.promise.catch(() => {});
        pending.reject(error);
      }
      throw error;
    }

    const pending = handler.actorPromise;
    if (pending) pending.resolve();
    // Clearing the resolvers marks the handler as fully loaded.
    handler.actorPromise = void 0;
    return handler.actor;
  }

  /**
   * Returns the connection global state of a loaded (or loading) actor.
   * @throws If `loadActor` has not been called for `actorId` on this driver.
   */
  getGenericConnGlobalState(actorId) {
    const handler = this.#actors.get(actorId);
    if (!handler) {
      throw new Error(`Actor ${actorId} not loaded`);
    }
    return handler.genericConnGlobalState;
  }

  /** Returns driver context for the actor: `{ state }` holding the Durable Object `ctx`. */
  getContext(actorId) {
    const state = this.#globalState.getDOState(actorId);
    return { state: state.ctx };
  }

  /** Reads the actor's persisted data from Durable Object storage. */
  async readPersistedData(actorId) {
    return await this.#getDOCtx(actorId).storage.get(KEYS.PERSIST_DATA);
  }

  /** Writes the actor's persisted data to Durable Object storage. */
  async writePersistedData(actorId, data) {
    await this.#getDOCtx(actorId).storage.put(KEYS.PERSIST_DATA, data);
  }

  /** Sets the Durable Object alarm for `actor` (looked up via `actor.id`) to `timestamp`. */
  async setAlarm(actor, timestamp) {
    await this.#getDOCtx(actor.id).storage.setAlarm(timestamp);
  }

  /** Returns the Durable Object storage's `sql` handle for the actor. */
  async getDatabase(actorId) {
    return this.#getDOCtx(actorId).storage.sql;
  }
};
/**
 * Creates an actor-driver builder bound to `globalState`.
 *
 * The returned function takes (registryConfig, runConfig, managerDriver,
 * inlineClient) and constructs a CloudflareActorsActorDriver that shares the
 * given global state.
 */
function createCloudflareActorsActorDriverBuilder(globalState) {
  return (registryConfig, runConfig, managerDriver, inlineClient) =>
    new CloudflareActorsActorDriver(
      registryConfig,
      runConfig,
      managerDriver,
      inlineClient,
      globalState
    );
}
// src/log.ts
import { getLogger } from "@rivetkit/core/log";
var LOGGER_NAME = "driver-cloudflare-workers";
function logger() {
return getLogger(LOGGER_NAME);
}
// src/actor-handler-do.ts
// Durable Object storage keys used for an actor's name, key and persisted data.
// They are written together by the Durable Object's initialize() and read back
// by the Durable Object and the actor driver.
var KEYS = {
NAME: "rivetkit:name",
KEY: "rivetkit:key",
PERSIST_DATA: "rivetkit:data"
};
/**
 * Builds the Durable Object class that hosts a single actor.
 *
 * @param registry Registry whose `config` is passed to the drivers.
 * @param runConfig Run configuration; its `driver.actor` slot is filled in
 *   from inside the Durable Object the first time an actor is loaded.
 * @returns A class extending `DurableObject` exposing `initialize`, `fetch`
 *   and `alarm`.
 */
function createActorDurableObject(registry, runConfig) {
  const globalState = new CloudflareDurableObjectGlobalState();
  return class ActorHandler extends DurableObject {
    // `{ name, key }` once the actor is known to be initialized.
    #initialized;
    // Resolvers for an in-flight storage check; only set while a check is running.
    #initializedPromise;
    // `{ actorRouter, actorDriver }` once the actor has been wired up.
    #actor;

    /**
     * Ensures the actor is initialized and loaded.
     *
     * @returns `{ actorRouter, actorDriver }` for this Durable Object.
     * @throws "Not initialized" if storage holds no persisted data and
     *   `initialize` has not been called; also rethrows storage/load errors.
     */
    async #loadActor() {
      if (!this.#initialized) {
        if (this.#initializedPromise) {
          // Another call is already checking storage; wait for its outcome.
          await this.#initializedPromise.promise;
        } else {
          const pending = Promise.withResolvers();
          this.#initializedPromise = pending;
          try {
            const res = await this.ctx.storage.get([
              KEYS.NAME,
              KEYS.KEY,
              KEYS.PERSIST_DATA
            ]);
            if (res.get(KEYS.PERSIST_DATA)) {
              const name = res.get(KEYS.NAME);
              if (!name) throw new Error("missing actor name");
              const key = res.get(KEYS.KEY);
              if (!key) throw new Error("missing actor key");
              logger().debug("already initialized", { name, key });
              this.#initialized = { name, key };
            } else {
              logger().debug("waiting to initialize");
            }
          } finally {
            // Always release waiters and clear the marker. Otherwise a
            // missing record or a storage error would leave every
            // concurrent and later caller waiting forever. Released waiters
            // re-check `#initialized` below.
            this.#initializedPromise = void 0;
            pending.resolve();
          }
        }
      }
      if (this.#actor) {
        return this.#actor;
      }
      if (!this.#initialized) throw new Error("Not initialized");

      const actorId = this.ctx.id.toString();
      globalState.setDOState(actorId, { ctx: this.ctx, env });
      // The actor driver can only be built from inside the Durable Object.
      runConfig.driver.actor = createCloudflareActorsActorDriverBuilder(globalState);
      const managerDriver = runConfig.driver.manager(
        registry.config,
        runConfig
      );
      const inlineClient = createClientWithDriver(
        createInlineClientDriver(managerDriver)
      );
      const actorDriver = runConfig.driver.actor(
        registry.config,
        runConfig,
        managerDriver,
        inlineClient
      );
      const actorRouter = createActorRouter(runConfig, actorDriver);

      // Keep the driver next to the router so alarm() reuses the same driver
      // (and therefore the same actor instance) as fetch().
      const loaded = { actorRouter, actorDriver };
      this.#actor = loaded;
      try {
        await actorDriver.loadActor(actorId);
      } catch (error) {
        // Do not keep a driver around whose actor failed to load.
        if (this.#actor === loaded) this.#actor = void 0;
        throw error;
      }
      return loaded;
    }

    /** RPC called by the service that creates the DO to initialize it. */
    async initialize(req) {
      await this.ctx.storage.put({
        [KEYS.NAME]: req.name,
        [KEYS.KEY]: req.key,
        [KEYS.PERSIST_DATA]: serializeEmptyPersistData(req.input)
      });
      this.#initialized = {
        name: req.name,
        key: req.key
      };
      logger().debug("initialized actor", { key: req.key });
      await this.#loadActor();
    }

    /** Routes an incoming request to the actor router, passing the actor ID along. */
    async fetch(request) {
      const { actorRouter } = await this.#loadActor();
      const actorId = this.ctx.id.toString();
      return await actorRouter.fetch(request, {
        actorId
      });
    }

    /**
     * Durable Object alarm handler: runs `onAlarm` on the loaded actor.
     *
     * Uses the driver created by #loadActor. Building a new driver here
     * would start a second actor instance, because each driver keeps its own
     * map of loaded actors.
     */
    async alarm() {
      const { actorDriver } = await this.#loadActor();
      const actorId = this.ctx.id.toString();
      const actor = await actorDriver.loadActor(actorId);
      await actor.onAlarm();
    }
  };
}
// src/config.ts
import { RunConfigSchema } from "@rivetkit/core/driver-helpers";
import { z } from "zod";
// User-facing config schema: RunConfigSchema with its default removed, minus
// the `driver` and `getUpgradeWebSocket` fields (createServer supplies those),
// plus an optional `app` field. The whole config defaults to `{}`.
var ConfigSchema = RunConfigSchema.removeDefault()
  .omit({ driver: true, getUpgradeWebSocket: true })
  .extend({ app: z.custom().optional() })
  .default({});
// src/manager-driver.ts
import {
HEADER_AUTH_DATA,
HEADER_CONN_PARAMS,
HEADER_ENCODING,
HEADER_EXPOSE_INTERNAL_ERROR
} from "@rivetkit/core/driver-helpers";
import { ActorAlreadyExists, InternalError } from "@rivetkit/core/errors";
// src/util.ts
// Marker emitted in place of an empty key array.
var EMPTY_KEY = "(none)";
// Separator placed between serialized key parts.
var KEY_SEPARATOR = ",";

/**
 * Serializes an actor name and key into a single `name:key` string.
 *
 * Colons in `name` are escaped as `\:`. The key portion is EMPTY_KEY for an
 * empty key array, otherwise the output of serializeKey. Colons inside key
 * parts are not escaped.
 */
function serializeNameAndKey(name, key) {
  const namePortion = name.replace(/:/g, "\\:");
  const keyPortion = key.length === 0 ? EMPTY_KEY : serializeKey(key);
  return `${namePortion}:${keyPortion}`;
}

/**
 * Serializes a key array into a comma-separated string.
 *
 * - An empty array yields EMPTY_KEY.
 * - A part exactly equal to EMPTY_KEY is emitted with a leading backslash so
 *   it cannot be mistaken for the empty-array marker.
 * - In every other part, backslashes are doubled and commas are
 *   backslash-escaped.
 */
function serializeKey(key) {
  if (key.length === 0) {
    return EMPTY_KEY;
  }
  const pieces = [];
  for (const part of key) {
    if (part === EMPTY_KEY) {
      pieces.push(`\\${EMPTY_KEY}`);
      continue;
    }
    // Backslashes must be escaped first so the comma escapes are not doubled.
    pieces.push(part.replace(/\\/g, "\\\\").replace(/,/g, "\\,"));
  }
  return pieces.join(KEY_SEPARATOR);
}
// src/manager-driver.ts
// Builders for the KV keys written by the manager driver.
var KEYS2 = {
  ACTOR: {
    // Combined key for actor metadata (name and key)
    metadata: (actorId) => `actor:${actorId}:metadata`,
    // Key index function for actor lookup.
    // The actor name is part of the index key: indexing on the key alone
    // makes actors with different names but the same key overwrite each
    // other's entry.
    // NOTE(review): this changes the index key format; no reader of this
    // index is visible in this file — confirm nothing else depends on it.
    keyIndex: (name, key = []) => {
      return `actor_key:${serializeNameAndKey(name, key)}`;
    }
  }
};
// Lowercase header names that proxyWebSocket keeps on the request it forwards
// to the Durable Object; every other incoming header is deleted before the
// RivetKit-specific headers are set.
var STANDARD_WEBSOCKET_HEADERS = [
"connection",
"upgrade",
"sec-websocket-key",
"sec-websocket-version",
"sec-websocket-protocol",
"sec-websocket-extensions"
];
/**
 * Manager driver for Cloudflare Workers.
 *
 * Actors live in Durable Objects reached through the `ACTOR_DO` binding;
 * actor metadata (name and key) is recorded in the `ACTOR_KV` namespace.
 */
var CloudflareActorsManagerDriver = class {
  /**
   * Sends `actorRequest` to the Durable Object identified by `actorId`,
   * using the ambient Worker env.
   */
  async sendRequest(actorId, actorRequest) {
    const env3 = getCloudflareAmbientEnv();
    logger().debug("sending request to durable object", {
      actorId,
      method: actorRequest.method,
      url: actorRequest.url
    });
    const id = env3.ACTOR_DO.idFromString(actorId);
    const stub = env3.ACTOR_DO.get(id);
    return await stub.fetch(actorRequest);
  }

  /**
   * Opens a WebSocket to the actor's Durable Object at `path`.
   *
   * @returns The accepted client WebSocket; an "open" event is dispatched on
   *   it asynchronously (next timer tick).
   * @throws InternalError if the Durable Object response carries no WebSocket.
   */
  async openWebSocket(path, actorId, encoding, params) {
    const env3 = getCloudflareAmbientEnv();
    logger().debug("opening websocket to durable object", { actorId, path });
    const id = env3.ACTOR_DO.idFromString(actorId);
    const stub = env3.ACTOR_DO.get(id);
    const headers = {
      Upgrade: "websocket",
      Connection: "Upgrade",
      [HEADER_EXPOSE_INTERNAL_ERROR]: "true",
      [HEADER_ENCODING]: encoding
    };
    if (params) {
      headers[HEADER_CONN_PARAMS] = JSON.stringify(params);
    }
    headers["sec-websocket-protocol"] = "rivetkit";
    const url = `http://actor${path}`;
    logger().debug("rewriting websocket url", {
      from: path,
      to: url
    });
    const response = await stub.fetch(url, {
      headers
    });
    const webSocket = response.webSocket;
    if (!webSocket) {
      throw new InternalError(
        "missing websocket connection in response from DO"
      );
    }
    logger().debug("durable object websocket connection open", {
      actorId
    });
    webSocket.accept();
    // Deferred so the caller can attach its handlers before "open" is delivered.
    setTimeout(() => {
      const event = new Event("open");
      webSocket.onopen?.(event);
      webSocket.dispatchEvent(event);
    }, 0);
    return webSocket;
  }

  /**
   * Forwards `actorRequest` to the actor's Durable Object using the bindings
   * on the Hono context `c`.
   */
  async proxyRequest(c, actorRequest, actorId) {
    logger().debug("forwarding request to durable object", {
      actorId,
      method: actorRequest.method,
      url: actorRequest.url
    });
    const id = c.env.ACTOR_DO.idFromString(actorId);
    const stub = c.env.ACTOR_DO.get(id);
    return await stub.fetch(actorRequest);
  }

  /**
   * Forwards an incoming WebSocket upgrade request to the actor's Durable
   * Object.
   *
   * Only the headers in STANDARD_WEBSOCKET_HEADERS are kept from the incoming
   * request; the encoding, connection params and auth data travel in
   * RivetKit headers.
   *
   * @returns The Durable Object's response, or a 426 response when the
   *   request is not a WebSocket upgrade.
   */
  async proxyWebSocket(c, path, actorId, encoding, params, authData) {
    logger().debug("forwarding websocket to durable object", {
      actorId,
      path
    });
    const upgradeHeader = c.req.header("Upgrade");
    // The Upgrade token is ASCII case-insensitive (RFC 6455 §4.2.1), so
    // "WebSocket" must be accepted as well as "websocket".
    if (!upgradeHeader || upgradeHeader.toLowerCase() !== "websocket") {
      return new Response("Expected Upgrade: websocket", {
        status: 426
      });
    }
    const newUrl = new URL(`http://actor${path}`);
    const actorRequest = new Request(newUrl, c.req.raw);
    logger().debug("rewriting websocket url", {
      from: c.req.url,
      to: actorRequest.url
    });
    // Collect the names first; deleting while iterating would skip entries.
    const headerKeys = [];
    actorRequest.headers.forEach((v, k) => headerKeys.push(k));
    for (const k of headerKeys) {
      if (!STANDARD_WEBSOCKET_HEADERS.includes(k)) {
        actorRequest.headers.delete(k);
      }
    }
    actorRequest.headers.set(HEADER_EXPOSE_INTERNAL_ERROR, "true");
    actorRequest.headers.set(HEADER_ENCODING, encoding);
    if (params) {
      actorRequest.headers.set(HEADER_CONN_PARAMS, JSON.stringify(params));
    }
    if (authData) {
      actorRequest.headers.set(HEADER_AUTH_DATA, JSON.stringify(authData));
    }
    const id = c.env.ACTOR_DO.idFromString(actorId);
    const stub = c.env.ACTOR_DO.get(id);
    return await stub.fetch(actorRequest);
  }

  /**
   * Looks up an actor by ID in KV.
   * @returns `{ actorId, name, key }`, or undefined if no metadata exists.
   */
  async getForId({
    c,
    actorId
  }) {
    const env3 = getCloudflareAmbientEnv();
    const actorData = await env3.ACTOR_KV.get(KEYS2.ACTOR.metadata(actorId), {
      type: "json"
    });
    if (!actorData) {
      return void 0;
    }
    return this.#buildActorOutput(actorId, actorData);
  }

  /**
   * Looks up an actor by name and key.
   *
   * The actor ID is derived from the serialized name and key via
   * `idFromName`; the actor counts as existing only if KV holds metadata for
   * that ID.
   *
   * @returns `{ actorId, name, key }` from the stored metadata, or undefined.
   */
  async getWithKey({
    c,
    name,
    key
  }) {
    const env3 = getCloudflareAmbientEnv();
    logger().debug("getWithKey: searching for actor", { name, key });
    const nameKeyString = serializeNameAndKey(name, key);
    const actorId = env3.ACTOR_DO.idFromName(nameKeyString).toString();
    const actorData = await env3.ACTOR_KV.get(KEYS2.ACTOR.metadata(actorId), {
      type: "json"
    });
    if (!actorData) {
      logger().debug("getWithKey: no actor found with matching name and key", {
        name,
        key,
        actorId
      });
      return void 0;
    }
    logger().debug("getWithKey: found actor with matching name and key", {
      actorId,
      name,
      key
    });
    // Reuse the metadata that was just read instead of reading KV again.
    return this.#buildActorOutput(actorId, actorData);
  }

  /** Returns the actor for the given name and key, creating it if it is not found. */
  async getOrCreateWithKey(input) {
    const getOutput = await this.getWithKey(input);
    if (getOutput) {
      return getOutput;
    } else {
      return await this.createActor(input);
    }
  }

  /**
   * Creates an actor: initializes its Durable Object, then records its
   * metadata and key index entry in KV.
   *
   * @returns `{ actorId, name, key }` for the new actor.
   * @throws ActorAlreadyExists if an actor with this name and key is found.
   */
  async createActor({
    c,
    name,
    key,
    input
  }) {
    const env3 = getCloudflareAmbientEnv();
    const existingActor = await this.getWithKey({ c, name, key });
    if (existingActor) {
      throw new ActorAlreadyExists(name, key);
    }
    const nameKeyString = serializeNameAndKey(name, key);
    const doId = env3.ACTOR_DO.idFromName(nameKeyString);
    const actorId = doId.toString();
    const actor = env3.ACTOR_DO.get(doId);
    await actor.initialize({
      name,
      key,
      input
    });
    const actorData = { name, key };
    await env3.ACTOR_KV.put(
      KEYS2.ACTOR.metadata(actorId),
      JSON.stringify(actorData)
    );
    await env3.ACTOR_KV.put(KEYS2.ACTOR.keyIndex(name, key), actorId);
    return {
      actorId,
      name,
      key
    };
  }

  // Helper method to build actor output from an ID and its stored metadata
  #buildActorOutput(actorId, actorData) {
    return {
      actorId,
      name: actorData.name,
      key: actorData.key
    };
  }
};
// src/websocket.ts
import { defineWebSocketHelper, WSContext } from "hono/ws";
/**
 * Hono WebSocket helper for Cloudflare Workers.
 *
 * Creates a WebSocketPair, wires the supplied event callbacks to the server
 * end and returns a 101 response that hands the client end to the runtime.
 * Returns undefined (no response) when the request is not a WebSocket upgrade.
 */
var upgradeWebSocket = defineWebSocketHelper(async (c, events) => {
  const upgradeHeader = c.req.header("Upgrade");
  // The Upgrade token is ASCII case-insensitive (RFC 6455 §4.2.1).
  if (upgradeHeader?.toLowerCase() !== "websocket") {
    return;
  }
  const webSocketPair = new WebSocketPair();
  const client = webSocketPair[0];
  const server = webSocketPair[1];
  const wsContext = new WSContext({
    close: (code, reason) => server.close(code, reason),
    get protocol() {
      return server.protocol;
    },
    raw: server,
    get readyState() {
      return server.readyState;
    },
    url: server.url ? new URL(server.url) : null,
    send: (source) => server.send(source)
  });
  // A listener is attached only for callbacks present at upgrade time; the
  // callback itself is looked up again when the event fires.
  if (events.onClose) {
    server.addEventListener(
      "close",
      (evt) => events.onClose?.(evt, wsContext)
    );
  }
  if (events.onMessage) {
    server.addEventListener(
      "message",
      (evt) => events.onMessage?.(evt, wsContext)
    );
  }
  if (events.onError) {
    server.addEventListener(
      "error",
      (evt) => events.onError?.(evt, wsContext)
    );
  }
  server.accept?.();
  // There is no separate "open" event to listen for here, so the callback is
  // invoked right after accepting.
  events.onOpen?.(new Event("open"), wsContext);
  return new Response(null, {
    status: 101,
    headers: {
      // HACK: Required in order for Cloudflare to not error with "Network connection lost"
      //
      // This bug is undocumented. Cannot easily reproduce outside of RivetKit.
      "Sec-WebSocket-Protocol": "rivetkit"
    },
    webSocket: client
  });
});
// src/handler.ts
function getCloudflareAmbientEnv() {
return env2;
}
/**
 * Convenience wrapper: builds the server for `registry` and immediately
 * creates its handler with no custom Hono app.
 *
 * @returns `{ handler, ActorHandler }` as produced by `createServer(...).createHandler()`.
 */
function createServerHandler(registry, inputConfig) {
  const server = createServer(registry, inputConfig);
  return server.createHandler();
}
/**
 * Sets up the Cloudflare Workers server for `registry`.
 *
 * @param registry Actor registry to serve.
 * @param inputConfig Optional config, validated against ConfigSchema.
 * @returns `{ client, createHandler }`. `createHandler(hono?)` returns the
 *   Worker `handler` (with a `fetch` method) and the `ActorHandler` Durable
 *   Object class. When no Hono app is passed, a new one is created and the
 *   registry router is mounted at "/registry"; a caller-supplied app is used
 *   as-is.
 */
function createServer(registry, inputConfig) {
  const parsedConfig = ConfigSchema.parse(inputConfig);
  const driver = {
    name: "cloudflare-workers",
    manager: () => new CloudflareActorsManagerDriver(),
    // HACK: We can't build the actor driver until we're inside the Durable Object
    actor: void 0
  };
  const runConfig = {
    driver,
    getUpgradeWebSocket: () => upgradeWebSocket,
    ...parsedConfig
  };
  const ActorHandler2 = createActorDurableObject(registry, runConfig);
  const serverOutput = registry.createServer(runConfig);

  const createHandler = (hono) => {
    const app = hono ?? new Hono();
    if (!hono) {
      app.route("/registry", serverOutput.hono);
    }
    return {
      handler: {
        fetch: (request, env3, ctx) => app.fetch(request, env3, ctx)
      },
      ActorHandler: ActorHandler2
    };
  };

  return {
    client: serverOutput.client,
    createHandler
  };
}
export {
createServer,
createServerHandler
};
//# sourceMappingURL=mod.js.map