// rivetkit
// Version: (not specified)
// Lightweight libraries for building stateful actors on edge platforms
// 1,712 lines (1,698 loc) • 123 kB
// JavaScript
import {
configureInspectorAccessToken,
getInspectorUrl,
inspectorLogger,
isInspectorEnabled,
secureInspector
} from "./chunk-KL4V2ULR.js";
import {
ActorDefinition,
RemoteManagerDriver,
createActorInspectorRouter,
createClientWithDriver,
deserializeActorKey,
generateConnSocketId,
getEndpoint,
lookupInRegistry,
serializeActorKey
} from "./chunk-H7E2UU23.js";
import {
CreateActorSchema
} from "./chunk-KGDZYQYE.js";
import {
ActionContext,
HTTP_ACTION_REQUEST_VERSIONED,
HTTP_ACTION_RESPONSE_VERSIONED,
HTTP_RESPONSE_ERROR_VERSIONED,
RunnerConfigSchema,
TO_SERVER_VERSIONED,
createVersionedDataHandler,
parseMessage,
serializeEmptyPersistData
} from "./chunk-QRFXXTLG.js";
import {
EncodingSchema,
HEADER_ACTOR_ID,
HEADER_CONN_ID,
HEADER_CONN_PARAMS,
HEADER_CONN_TOKEN,
HEADER_ENCODING,
HEADER_RIVET_ACTOR,
HEADER_RIVET_TARGET,
PATH_CONNECT_WEBSOCKET,
PATH_RAW_WEBSOCKET_PREFIX,
WS_PROTOCOL_ACTOR,
WS_PROTOCOL_CONN_ID,
WS_PROTOCOL_CONN_PARAMS,
WS_PROTOCOL_CONN_TOKEN,
WS_PROTOCOL_ENCODING,
WS_PROTOCOL_PATH,
WS_PROTOCOL_TARGET,
WS_PROTOCOL_TRANSPORT,
contentTypeForEncoding,
deserializeWithEncoding,
encodingIsBinary,
generateRandomString,
loggerWithoutContext,
serializeWithEncoding
} from "./chunk-MLQIYKAZ.js";
import {
configureBaseLogger,
configureDefaultLogger,
getLogger
} from "./chunk-7E5K3375.js";
import {
VERSION,
assertUnreachable,
bufferToArrayBuffer,
deconstructError,
noopNext,
package_default,
promiseWithResolvers,
setLongTimeout,
stringifyError
} from "./chunk-HI55LHM3.js";
import {
ActorAlreadyExists,
ConnNotFound,
IncorrectConnToken,
InternalError,
InvalidEncoding,
InvalidParams,
InvalidRequest,
MissingActorHeader,
Unsupported,
UserError,
WebSocketsNotEnabled
} from "./chunk-YPZFLUO6.js";
// src/actor/config.ts
import { z } from "zod";
/**
 * Zod schema validating a user-supplied actor definition.
 *
 * Hooks/factories are optional functions, `actions` defaults to `{}`, and
 * `options` is a strict object whose numeric entries must be positive and have
 * defaults. The three refinements reject definitions that set both a static
 * value and its factory (`state`/`createState`, `connState`/`createConnState`,
 * `vars`/`createVars`).
 */
var ActorConfigSchema = (() => {
  // Every lifecycle hook and factory field is an optional function.
  const optionalFunction = () => z.function().optional();
  // Tunable numeric options must be strictly positive and fall back to a default.
  const positiveWithDefault = (fallback) => z.number().positive().default(fallback);
  const optionsSchema = z.object({
    createVarsTimeout: positiveWithDefault(5e3),
    createConnStateTimeout: positiveWithDefault(5e3),
    onConnectTimeout: positiveWithDefault(5e3),
    // This must be less than ACTOR_STOP_THRESHOLD_MS
    onStopTimeout: positiveWithDefault(5e3),
    stateSaveInterval: positiveWithDefault(1e4),
    actionTimeout: positiveWithDefault(6e4),
    // Max time to wait for waitUntil background promises during shutdown
    waitUntilTimeout: positiveWithDefault(15e3),
    connectionLivenessTimeout: positiveWithDefault(2500),
    connectionLivenessInterval: positiveWithDefault(5e3),
    noSleep: z.boolean().default(false),
    sleepTimeout: positiveWithDefault(3e4)
  }).strict().default({});
  const baseSchema = z.object({
    onCreate: optionalFunction(),
    onStart: optionalFunction(),
    onStop: optionalFunction(),
    onStateChange: optionalFunction(),
    onBeforeConnect: optionalFunction(),
    onConnect: optionalFunction(),
    onDisconnect: optionalFunction(),
    onBeforeActionResponse: optionalFunction(),
    onFetch: optionalFunction(),
    onWebSocket: optionalFunction(),
    actions: z.record(z.function()).default({}),
    state: z.any().optional(),
    createState: optionalFunction(),
    connState: z.any().optional(),
    createConnState: optionalFunction(),
    vars: z.any().optional(),
    db: z.any().optional(),
    createVars: optionalFunction(),
    options: optionsSchema
  }).strict();
  // Each pair below is mutually exclusive: at most one side may be defined.
  return baseSchema.refine(
    (data) => data.state === void 0 || data.createState === void 0,
    {
      message: "Cannot define both 'state' and 'createState'",
      path: ["state"]
    }
  ).refine(
    (data) => data.connState === void 0 || data.createConnState === void 0,
    {
      message: "Cannot define both 'connState' and 'createConnState'",
      path: ["connState"]
    }
  ).refine(
    (data) => data.vars === void 0 || data.createVars === void 0,
    {
      message: "Cannot define both 'vars' and 'createVars'",
      path: ["vars"]
    }
  );
})();
// src/actor/router.ts
import { Hono } from "hono";
import { cors } from "hono/cors";
import invariant2 from "invariant";
// src/actor/router-endpoints.ts
import * as cbor from "cbor-x";
import { streamSSE } from "hono/streaming";
import invariant from "invariant";
// src/manager/log.ts
/**
 * Returns the logger for the actor manager, registered under "actor-manager".
 */
function logger() {
  const loggerName = "actor-manager";
  return getLogger(loggerName);
}
// src/manager/hono-websocket-adapter.ts
/**
 * Wraps a Hono `WSContext` in a browser-style WebSocket surface: readyState
 * constants, `on*` handler properties, and add/removeEventListener.
 *
 * Outgoing data goes through `send`, which normalizes typed-array views and
 * Blobs to ArrayBuffers before calling `WSContext.send`. Incoming traffic is
 * pushed in by the owner through `_handleMessage`, `_handleClose` and
 * `_handleError`.
 */
var HonoWebSocketAdapter = class {
  // WebSocket readyState values
  CONNECTING = 0;
  OPEN = 1;
  CLOSING = 2;
  CLOSED = 3;
  #ws;
  // Start as OPEN since WSContext is already connected
  #readyState = 1;
  #eventListeners = /* @__PURE__ */ new Map();
  #closeCode;
  #closeReason;
  constructor(ws) {
    this.#ws = ws;
    this.#readyState = this.OPEN;
    // Emit "open" on a later tick so the caller can attach listeners first.
    // If the socket was closed before the timer ran, do not emit "open" after "close".
    setTimeout(() => {
      if (this.#readyState !== this.OPEN) {
        return;
      }
      this.#fireEvent("open", { type: "open", target: this });
    }, 0);
  }
  get readyState() {
    return this.#readyState;
  }
  get binaryType() {
    return "arraybuffer";
  }
  // binaryType is fixed to "arraybuffer"; assignments are ignored.
  set binaryType(value) {
  }
  get bufferedAmount() {
    return 0;
  }
  get extensions() {
    return "";
  }
  get protocol() {
    return "";
  }
  get url() {
    return "";
  }
  /**
   * Sends `data` through the wrapped WSContext.
   *
   * Strings and ArrayBuffers pass through unchanged. Typed-array/DataView data
   * is copied into a standalone ArrayBuffer. A Blob is converted asynchronously,
   * and a conversion failure is reported as an "error" event. Any other value is
   * stringified.
   *
   * @throws {Error} if the socket is not OPEN, or rethrows errors from WSContext.send.
   */
  send(data) {
    if (this.readyState !== this.OPEN) {
      throw new Error("WebSocket is not open");
    }
    try {
      logger().debug({
        msg: "bridge sending data",
        dataType: typeof data,
        isString: typeof data === "string",
        isArrayBuffer: data instanceof ArrayBuffer,
        dataStr: typeof data === "string" ? data.substring(0, 100) : "<non-string>"
      });
      if (typeof data === "string") {
        this.#ws.send(data);
      } else if (data instanceof ArrayBuffer) {
        this.#ws.send(data);
      } else if (ArrayBuffer.isView(data)) {
        // Copy out exactly the bytes covered by the view (it may span only part of its buffer).
        const buffer = data.buffer.slice(
          data.byteOffset,
          data.byteOffset + data.byteLength
        );
        // SharedArrayBuffer is not a global in every runtime (e.g. non-cross-origin-isolated
        // contexts), so only test for it when it exists to avoid a ReferenceError.
        if (typeof SharedArrayBuffer !== "undefined" && buffer instanceof SharedArrayBuffer) {
          const arrayBuffer = new ArrayBuffer(buffer.byteLength);
          new Uint8Array(arrayBuffer).set(new Uint8Array(buffer));
          this.#ws.send(arrayBuffer);
        } else {
          this.#ws.send(buffer);
        }
      } else if (data instanceof Blob) {
        data.arrayBuffer().then((buffer) => {
          this.#ws.send(buffer);
        }).catch((error) => {
          logger().error({
            msg: "failed to convert blob to arraybuffer",
            error
          });
          this.#fireEvent("error", { type: "error", target: this, error });
        });
      } else {
        logger().warn({
          msg: "unsupported data type, converting to string",
          dataType: typeof data,
          data
        });
        this.#ws.send(String(data));
      }
    } catch (error) {
      logger().error({ msg: "error sending websocket data", error });
      this.#fireEvent("error", { type: "error", target: this, error });
      throw error;
    }
  }
  /**
   * Closes the wrapped WSContext and emits a "close" event.
   *
   * A call while CLOSING/CLOSED does nothing. If WSContext.close throws, the
   * socket still ends CLOSED and the close event reports 1006 ("Abnormal closure").
   */
  close(code = 1e3, reason = "") {
    if (this.readyState === this.CLOSING || this.readyState === this.CLOSED) {
      return;
    }
    this.#readyState = this.CLOSING;
    this.#closeCode = code;
    this.#closeReason = reason;
    try {
      this.#ws.close(code, reason);
      this.#readyState = this.CLOSED;
      this.#fireEvent("close", {
        type: "close",
        target: this,
        code,
        reason,
        wasClean: code === 1e3
      });
    } catch (error) {
      logger().error({ msg: "error closing websocket", error });
      this.#readyState = this.CLOSED;
      this.#fireEvent("close", {
        type: "close",
        target: this,
        code: 1006,
        reason: "Abnormal closure",
        wasClean: false
      });
    }
  }
  addEventListener(type, listener) {
    if (!this.#eventListeners.has(type)) {
      this.#eventListeners.set(type, /* @__PURE__ */ new Set());
    }
    this.#eventListeners.get(type).add(listener);
  }
  removeEventListener(type, listener) {
    const listeners = this.#eventListeners.get(type);
    if (listeners) {
      listeners.delete(listener);
    }
  }
  /**
   * Dispatches `event` to the listeners registered for `event.type` and to the
   * matching `on<type>` handler property, as `EventTarget.dispatchEvent` would.
   * @returns {boolean} always true
   */
  dispatchEvent(event) {
    this.#fireEvent(event.type, event);
    return true;
  }
  // Internal method to handle incoming messages from WSContext
  _handleMessage(data) {
    let messageData;
    if (typeof data === "string") {
      messageData = data;
    } else if (data instanceof ArrayBuffer || ArrayBuffer.isView(data)) {
      messageData = data;
    } else if (data && typeof data === "object" && "data" in data) {
      // Hono delivers a MessageEvent-like object; unwrap its payload.
      messageData = data.data;
    } else {
      messageData = String(data);
    }
    logger().debug({
      msg: "bridge handling message",
      dataType: typeof messageData,
      isArrayBuffer: messageData instanceof ArrayBuffer,
      // Truncate like send() does so large payloads do not flood the debug log.
      dataStr: typeof messageData === "string" ? messageData.substring(0, 100) : "<binary>"
    });
    this.#fireEvent("message", {
      type: "message",
      target: this,
      data: messageData
    });
  }
  // Internal method to handle close from WSContext
  _handleClose(code, reason) {
    this.#ws.close(1e3, "hack_force_close");
    if (this.readyState === this.CLOSED) return;
    this.#readyState = this.CLOSED;
    this.#closeCode = code;
    this.#closeReason = reason;
    this.#fireEvent("close", {
      type: "close",
      target: this,
      code,
      reason,
      wasClean: code === 1e3
    });
  }
  // Internal method to handle errors from WSContext
  _handleError(error) {
    this.#fireEvent("error", {
      type: "error",
      target: this,
      error
    });
  }
  // Deliver `event` to every registered listener, then to the matching on<type> property.
  // Exceptions from individual handlers are logged and do not stop delivery.
  #fireEvent(type, event) {
    const listeners = this.#eventListeners.get(type);
    if (listeners) {
      for (const listener of listeners) {
        try {
          listener(event);
        } catch (error) {
          logger().error({ msg: `error in ${type} event listener`, error });
        }
      }
    }
    switch (type) {
      case "open":
        if (this.#onopen) {
          try {
            this.#onopen(event);
          } catch (error) {
            logger().error({ msg: "error in onopen handler", error });
          }
        }
        break;
      case "close":
        if (this.#onclose) {
          try {
            this.#onclose(event);
          } catch (error) {
            logger().error({ msg: "error in onclose handler", error });
          }
        }
        break;
      case "error":
        if (this.#onerror) {
          try {
            this.#onerror(event);
          } catch (error) {
            logger().error({ msg: "error in onerror handler", error });
          }
        }
        break;
      case "message":
        if (this.#onmessage) {
          try {
            this.#onmessage(event);
          } catch (error) {
            logger().error({ msg: "error in onmessage handler", error });
          }
        }
        break;
    }
  }
  // Event handler properties with getters/setters
  #onopen = null;
  #onclose = null;
  #onerror = null;
  #onmessage = null;
  get onopen() {
    return this.#onopen;
  }
  set onopen(handler) {
    this.#onopen = handler;
  }
  get onclose() {
    return this.#onclose;
  }
  set onclose(handler) {
    this.#onclose = handler;
  }
  get onerror() {
    return this.#onerror;
  }
  set onerror(handler) {
    this.#onerror = handler;
  }
  get onmessage() {
    return this.#onmessage;
  }
  set onmessage(handler) {
    this.#onmessage = handler;
  }
};
// src/actor/router-endpoints.ts
// Delay passed to `stream.sleep` between keep-alive "ping" events on SSE streams.
var SSE_PING_INTERVAL = 1000;
/**
 * Builds the Hono WebSocket event handlers for an actor connection socket.
 *
 * The actor is loaded up front. If loading fails, the returned handlers close
 * the socket with code 1011. Otherwise:
 * - `onOpen` creates the actor connection.
 * - `onMessage` parses each frame (bounded by `runConfig.maxIncomingMessageSize`)
 *   and hands it to `actor.processMessage`.
 * - `onClose` reports the disconnect to the actor.
 *
 * @param req - raw upgrade request, or a falsy value when none is available.
 * @param runConfig - run configuration (reads `maxIncomingMessageSize`).
 * @param actorDriver - driver used to load the actor.
 * @param actorId - id of the actor to connect to.
 * @param encoding - message encoding negotiated for this socket.
 * @param parameters - connection parameters forwarded to `createConn`.
 * @param connId - existing connection id when reconnecting, otherwise undefined.
 * @param connToken - token for `connId`, forwarded to `createConn`.
 * @returns {Promise<{onOpen, onMessage, onClose, onError}>} Hono WS event handlers.
 */
async function handleWebSocketConnect(req, runConfig, actorDriver, actorId, encoding, parameters, connId, connToken) {
  const exposeInternalError = req ? getRequestExposeInternalError(req) : false;
  // Settles once the connection exists (or cannot be created) so that the
  // message/close handlers can wait for it.
  const {
    promise: handlersPromise,
    resolve: handlersResolve,
    reject: handlersReject
  } = promiseWithResolvers();
  // handlersPromise can be rejected (setup failed, or the socket closed first)
  // before any handler is chained onto it. Mark it handled up front so the
  // runtime never reports an unhandled rejection; the chains below still see it.
  handlersPromise.catch(() => {
  });
  let actor2;
  try {
    actor2 = await actorDriver.loadActor(actorId);
  } catch (error) {
    // The actor never loaded, so `actor2` is undefined: log with the context-free logger.
    return {
      onOpen: (_evt, ws) => {
        const { code } = deconstructError(
          error,
          loggerWithoutContext(),
          {
            wsEvent: "open"
          },
          exposeInternalError
        );
        ws.close(1011, code);
      },
      onMessage: (_evt, ws) => {
        ws.close(1011, "Actor not loaded");
      },
      onClose: (_event, _ws) => {
      },
      onError: (_error) => {
      }
    };
  }
  const closePromise = promiseWithResolvers();
  const socketId = generateConnSocketId();
  // Log `error` through `log` and close the socket with the resulting error code.
  const closeWithError = (ws, log, wsEvent, error) => {
    const { code } = deconstructError(
      error,
      log,
      { wsEvent },
      exposeInternalError
    );
    ws.close(1011, code);
  };
  return {
    onOpen: (_evt, ws) => {
      actor2.rLog.debug("actor websocket open");
      (async () => {
        try {
          actor2.rLog.debug({
            msg: connId ? "websocket reconnection attempt" : "new websocket connection",
            connId,
            actorId
          });
          const conn = await actor2.createConn(
            {
              socketId,
              driverState: {
                [0 /* WEBSOCKET */]: {
                  encoding,
                  websocket: ws,
                  closePromise
                }
              }
            },
            parameters,
            req,
            connId,
            connToken
          );
          handlersResolve({ conn, actor: actor2, connId: conn.id });
        } catch (error) {
          handlersReject(error);
          closeWithError(ws, actor2.rLog, "open", error);
        }
      })();
    },
    onMessage: (evt, ws) => {
      handlersPromise.then(({ conn, actor: actor3 }) => {
        actor3.rLog.debug({ msg: "received message" });
        const value = evt.data.valueOf();
        parseMessage(value, {
          encoding,
          maxIncomingMessageSize: runConfig.maxIncomingMessageSize
        }).then((message) => {
          actor3.processMessage(message, conn).catch((error) => {
            closeWithError(ws, actor3.rLog, "message", error);
          });
        }).catch((error) => {
          closeWithError(ws, actor3.rLog, "message", error);
        });
      }).catch((error) => {
        closeWithError(ws, actor2.rLog, "message", error);
      });
    },
    onClose: (event, ws) => {
      // No-op if the connection was already established; otherwise wakes any pending
      // onMessage chains with an error.
      handlersReject(
        new Error(`WebSocket closed (${event.code}): ${event.reason}`)
      );
      closePromise.resolve();
      if (event.wasClean) {
        actor2.rLog.info({
          msg: "websocket closed",
          code: event.code,
          reason: event.reason,
          wasClean: event.wasClean
        });
      } else {
        actor2.rLog.warn({
          msg: "websocket closed",
          code: event.code,
          reason: event.reason,
          wasClean: event.wasClean
        });
      }
      ws.close(1e3, "hack_force_close");
      handlersPromise.then(({ conn, actor: actor3 }) => {
        const wasClean = event.wasClean || event.code === 1e3;
        actor3.__connDisconnected(conn, wasClean, socketId);
      }).catch((error) => {
        deconstructError(
          error,
          actor2.rLog,
          { wsEvent: "close" },
          exposeInternalError
        );
      });
    },
    onError: (wsError) => {
      try {
        actor2.rLog.warn({ msg: "websocket error", error: wsError });
      } catch (error) {
        deconstructError(
          error,
          actor2.rLog,
          { wsEvent: "error" },
          exposeInternalError
        );
      }
    }
  };
}
/**
 * Opens a Server-Sent Events connection to an actor.
 *
 * It loads the actor and creates an SSE connection for the request's encoding
 * and params (reconnecting when a connection id/token header is present). It
 * then writes a "ping" event every SSE_PING_INTERVAL until the stream is closed
 * or aborted. When the request is aborted, the connection is reported as
 * disconnected (unclean) to the actor at most once.
 *
 * @returns the streaming SSE response produced by `streamSSE`.
 */
async function handleSseConnect(c, _runConfig, actorDriver, actorId) {
  c.header("Content-Encoding", "Identity");
  const encoding = getRequestEncoding(c.req);
  const parameters = getRequestConnParams(c.req);
  const socketId = generateConnSocketId();
  const connId = c.req.header(HEADER_CONN_ID);
  const connToken = c.req.header(HEADER_CONN_TOKEN);
  return streamSSE(c, async (stream) => {
    let actor2;
    let conn;
    let disconnected = false;
    // Both the abort listener and the error path may tear the connection down;
    // make sure the actor is only told once.
    const disconnectOnce = () => {
      if (disconnected || !conn || actor2 === void 0) {
        return;
      }
      disconnected = true;
      actor2.__connDisconnected(conn, false, socketId);
    };
    try {
      actor2 = await actorDriver.loadActor(actorId);
      actor2.rLog.debug({
        msg: connId ? "sse reconnection attempt" : "sse open",
        connId
      });
      conn = await actor2.createConn(
        {
          socketId,
          driverState: {
            [1 /* SSE */]: {
              encoding,
              stream
            }
          }
        },
        parameters,
        c.req.raw,
        connId,
        connToken
      );
      const signal = c.req.raw.signal;
      const onRequestAbort = async () => {
        invariant(actor2, "actor should exist");
        const rLog = actor2.rLog ?? loggerWithoutContext();
        try {
          rLog.debug("sse stream aborted");
          disconnectOnce();
        } catch (error) {
          rLog.error({ msg: "error closing sse connection", error });
        }
      };
      if (signal.aborted) {
        // The client went away while the connection was being created; "abort"
        // has already fired and will not fire again, so release the connection now.
        await onRequestAbort();
        return;
      }
      signal.addEventListener("abort", onRequestAbort);
      while (true) {
        if (stream.closed || stream.aborted) {
          actor2?.rLog.debug({
            msg: "sse stream closed",
            closed: stream.closed,
            aborted: stream.aborted
          });
          break;
        }
        await stream.writeSSE({ event: "ping", data: "" });
        await stream.sleep(SSE_PING_INTERVAL);
      }
    } catch (error) {
      loggerWithoutContext().error({ msg: "error in sse connection", error });
      disconnectOnce();
      stream.close();
    }
  });
}
/**
 * Executes a single action over plain HTTP.
 *
 * The request body is a serialized HTTP action request whose `args` field is
 * CBOR. A short-lived HTTP connection is created for the call and always
 * reported as cleanly disconnected afterwards, even if the action throws. The
 * action's return value is CBOR-encoded into the serialized response.
 *
 * @returns a 200 response whose Content-Type matches the request encoding.
 */
async function handleAction(c, _runConfig, actorDriver, actionName, actorId) {
  const encoding = getRequestEncoding(c.req);
  const parameters = getRequestConnParams(c.req);
  const requestBytes = new Uint8Array(await c.req.arrayBuffer());
  const { args } = deserializeWithEncoding(
    encoding,
    requestBytes,
    HTTP_ACTION_REQUEST_VERSIONED
  );
  const actionArgs = cbor.decode(new Uint8Array(args));
  const socketId = generateConnSocketId();
  let targetActor;
  let httpConn;
  let result;
  try {
    targetActor = await actorDriver.loadActor(actorId);
    targetActor.rLog.debug({ msg: "handling action", actionName, encoding });
    httpConn = await targetActor.createConn(
      {
        socketId,
        driverState: { [2 /* HTTP */]: {} }
      },
      parameters,
      c.req.raw
    );
    const actionContext = new ActionContext(targetActor.actorContext, httpConn);
    result = await targetActor.executeAction(actionContext, actionName, actionArgs);
  } finally {
    // The HTTP connection only lives for this one call.
    if (httpConn) {
      targetActor?.__connDisconnected(httpConn, true, socketId);
    }
  }
  const body = serializeWithEncoding(
    encoding,
    { output: bufferToArrayBuffer(cbor.encode(result)) },
    HTTP_ACTION_RESPONSE_VERSIONED
  );
  const headers = { "Content-Type": contentTypeForEncoding(encoding) };
  return c.body(body, 200, headers);
}
/**
 * Delivers a client-to-server message to an existing connection over HTTP.
 *
 * @throws {ConnNotFound} if the actor has no connection with `connId`.
 * @throws {IncorrectConnToken} if `connToken` does not match the connection's token.
 * @returns an empty JSON body once the actor has processed the message.
 */
async function handleConnectionMessage(c, _runConfig, actorDriver, connId, connToken, actorId) {
  const encoding = getRequestEncoding(c.req);
  const payload = new Uint8Array(await c.req.arrayBuffer());
  const message = deserializeWithEncoding(encoding, payload, TO_SERVER_VERSIONED);
  const targetActor = await actorDriver.loadActor(actorId);
  const connection = targetActor.conns.get(connId);
  if (!connection) throw new ConnNotFound(connId);
  if (connection._token !== connToken) throw new IncorrectConnToken();
  await targetActor.processMessage(message, connection);
  return c.json({});
}
/**
 * Closes an existing SSE connection at the client's request.
 *
 * @throws {ConnNotFound} if the actor has no connection with `connId`.
 * @throws {IncorrectConnToken} if `connToken` does not match.
 * @throws {UserError} if the connection has no SSE driver state.
 * @returns an empty JSON body after the connection has been disconnected.
 */
async function handleConnectionClose(c, _runConfig, actorDriver, connId, connToken, actorId) {
  const targetActor = await actorDriver.loadActor(actorId);
  const connection = targetActor.conns.get(connId);
  if (!connection) throw new ConnNotFound(connId);
  if (connection._token !== connToken) throw new IncorrectConnToken();
  const driverState = connection.__socket?.driverState;
  const isSse = Boolean(driverState) && 1 /* SSE */ in driverState;
  if (!isSse) {
    throw new UserError(
      "Connection close is only supported for SSE connections"
    );
  }
  await connection.disconnect("Connection closed by client request");
  return c.json({});
}
/**
 * Builds Hono WebSocket handlers that hand a raw WebSocket to the actor's
 * `handleWebSocket`.
 *
 * On open, the socket is wrapped in a HonoWebSocketAdapter (stored on
 * `ws.__adapter`). The request path has its `/raw/websocket` prefix stripped
 * and is rewritten to `http://actor/<rest>` (query string preserved). Later
 * message/close/error events are forwarded to that adapter. Events that arrive
 * before open (no adapter yet) are ignored.
 *
 * @param req - original request used as the init for the rewritten Request, if truthy.
 * @param path4 - request path plus query string.
 */
async function handleRawWebSocketHandler(req, path4, actorDriver, actorId) {
  const targetActor = await actorDriver.loadActor(actorId);
  // Run `deliver` with the adapter attached in onOpen, if there is one.
  const forward = (ws, deliver) => {
    const adapter = ws.__adapter;
    if (adapter) {
      deliver(adapter);
    }
  };
  return {
    onOpen: (_evt, ws) => {
      const adapter = new HonoWebSocketAdapter(ws);
      ws.__adapter = adapter;
      const url = new URL(path4, "http://actor");
      const strippedPath = url.pathname.replace(/^\/raw\/websocket\/?/, "") || "/";
      const leadingSlash = strippedPath.startsWith("/") ? "" : "/";
      const normalizedPath = `${leadingSlash}${strippedPath}${url.search}`;
      const newRequest = new Request(
        `http://actor${normalizedPath}`,
        req || { method: "GET" }
      );
      targetActor.rLog.debug({
        msg: "rewriting websocket url",
        from: path4,
        to: newRequest.url,
        pathname: url.pathname,
        search: url.search,
        normalizedPath
      });
      targetActor.handleWebSocket(adapter, {
        request: newRequest
      });
    },
    onMessage: (event, ws) => forward(ws, (adapter) => adapter._handleMessage(event)),
    onClose: (evt, ws) => forward(
      ws,
      (adapter) => adapter._handleClose(evt?.code || 1006, evt?.reason || "")
    ),
    onError: (error, ws) => forward(ws, (adapter) => adapter._handleError(error))
  };
}
/**
 * Reads and validates the encoding header of a request.
 *
 * @throws {InvalidEncoding} with "undefined" when the header is missing/empty,
 *   or with the raw header value when it is not an accepted encoding.
 * @returns the parsed encoding.
 */
function getRequestEncoding(req) {
  const headerValue = req.header(HEADER_ENCODING);
  if (headerValue) {
    const parsed = EncodingSchema.safeParse(headerValue);
    if (parsed.success) {
      return parsed.data;
    }
    throw new InvalidEncoding(headerValue);
  }
  throw new InvalidEncoding("undefined");
}
/**
 * Whether internal error details may be exposed to the caller.
 *
 * The request is not consulted: this always returns false, so internal error
 * details are never exposed.
 */
function getRequestExposeInternalError(_req) {
  const exposeInternalErrors = false;
  return exposeInternalErrors;
}
/**
 * Parses the JSON connection-parameters header of a request.
 *
 * @returns the parsed value, or null when the header is missing/empty.
 * @throws {InvalidParams} when the header is present but is not valid JSON.
 */
function getRequestConnParams(req) {
  const rawParams = req.header(HEADER_CONN_PARAMS);
  if (!rawParams) {
    return null;
  }
  let parsedParams;
  try {
    parsedParams = JSON.parse(rawParams);
  } catch (parseError) {
    throw new InvalidParams(
      `Invalid params JSON: ${stringifyError(parseError)}`
    );
  }
  return parsedParams;
}
// src/common/router.ts
import * as cbor2 from "cbor-x";
/**
 * Returns the logger for the common HTTP router, registered under "router".
 */
function logger2() {
  const loggerName = "router";
  return getLogger(loggerName);
}
/**
 * Creates a Hono middleware that logs one debug entry per request after the
 * downstream handlers have run.
 *
 * The entry records method, path, response status, elapsed time ("<n>ms"),
 * request/response content-length and user agent.
 *
 * @param logger8 - logger whose `debug` method receives the entry.
 */
function loggerMiddleware(logger8) {
  return async (c, next) => {
    const { method, path: requestPath } = c.req;
    const startedAt = Date.now();
    await next();
    const elapsedMs = Date.now() - startedAt;
    // Read the response only after next(): downstream handlers may replace it.
    const { req, res } = c;
    logger8.debug({
      msg: "http request",
      method,
      path: requestPath,
      status: res.status,
      dt: `${elapsedMs}ms`,
      reqSize: req.header("content-length"),
      resSize: res.headers.get("content-length"),
      userAgent: req.header("user-agent")
    });
  };
}
/**
 * Fallback handler for unmatched routes: a plain-text 404 response.
 */
function handleRouteNotFound(c) {
  const notFoundStatus = 404;
  return c.text("Not Found (RivetKit)", notFoundStatus);
}
/**
 * Converts an error thrown by a route into a serialized error response.
 *
 * It uses the request's encoding, falling back to "json" when the encoding
 * header is missing or invalid. Metadata is CBOR-encoded only for binary
 * encodings and sent as null otherwise.
 *
 * @returns a response whose status is the status code derived from the error.
 */
function handleRouteError(error, c) {
  const exposeInternalError = getRequestExposeInternalError(c.req.raw);
  const details = deconstructError(
    error,
    logger2(),
    {
      method: c.req.method,
      path: c.req.path
    },
    exposeInternalError
  );
  const resolveEncoding = () => {
    try {
      return getRequestEncoding(c.req);
    } catch (_) {
      return "json";
    }
  };
  const encoding = resolveEncoding();
  // TODO: Cannot serialize non-binary meta since it requires ArrayBuffer atm
  const encodedMetadata = encodingIsBinary(encoding) ? bufferToArrayBuffer(cbor2.encode(details.metadata)) : null;
  const payload = {
    group: details.group,
    code: details.code,
    message: details.message,
    metadata: encodedMetadata
  };
  const output = serializeWithEncoding(
    encoding,
    payload,
    HTTP_RESPONSE_ERROR_VERSIONED
  );
  return c.body(output, { status: details.statusCode });
}
// src/actor/router.ts
/**
 * Builds the Hono router that serves one actor's HTTP surface.
 *
 * Every route resolves its actor from `c.env.actorId`. Routes cover connection
 * WebSockets, SSE, one-shot actions, connection message/close, raw HTTP/WebSocket
 * passthrough, and (when enabled) the inspector.
 *
 * @param runConfig - run configuration. `getUpgradeWebSocket` (optional) enables the
 *   WebSocket routes; `inspector` is used when the inspector is enabled.
 * @param actorDriver - driver used to load actors by id.
 * @param isTest - when truthy, also registers the `/.test/force-disconnect` route.
 * @returns the configured Hono router.
 */
function createActorRouter(runConfig, actorDriver, isTest) {
  const router = new Hono({ strict: false });
  router.use("*", loggerMiddleware(loggerWithoutContext()));
  router.get("/", (c) => {
    return c.text(
      "This is an RivetKit actor.\n\nLearn more at https://rivetkit.org"
    );
  });
  router.get("/health", (c) => {
    return c.text("ok");
  });
  if (isTest) {
    // Test-only hook: abruptly terminate a connection's transport without a clean close.
    router.post("/.test/force-disconnect", async (c) => {
      const connId = c.req.query("conn");
      if (!connId) {
        return c.text("Missing conn query parameter", 400);
      }
      const actor2 = await actorDriver.loadActor(c.env.actorId);
      const conn = actor2.__getConnForId(connId);
      if (!conn) {
        return c.text(`Connection not found: ${connId}`, 404);
      }
      const driverState = conn.__driverState;
      if (driverState && 0 /* WEBSOCKET */ in driverState) {
        const ws = driverState[0 /* WEBSOCKET */].websocket;
        ws.raw.terminate();
      } else if (driverState && 1 /* SSE */ in driverState) {
        const stream = driverState[1 /* SSE */].stream;
        stream.abort();
      }
      return c.json({ success: true });
    });
  }
  router.get(PATH_CONNECT_WEBSOCKET, async (c) => {
    const upgradeWebSocket = runConfig.getUpgradeWebSocket?.();
    if (!upgradeWebSocket) {
      return c.text(
        "WebSockets are not enabled for this driver. Use SSE instead.",
        400
      );
    }
    return upgradeWebSocket(async (c2) => {
      // Browsers cannot set custom headers on a WebSocket handshake, so the encoding,
      // connection params, id and token are carried as Sec-WebSocket-Protocol entries.
      const protocols = c2.req.header("sec-websocket-protocol");
      let encodingRaw;
      let connParamsRaw;
      let connIdRaw;
      let connTokenRaw;
      if (protocols) {
        const protocolList = protocols.split(",").map((p) => p.trim());
        for (const protocol of protocolList) {
          if (protocol.startsWith(WS_PROTOCOL_ENCODING)) {
            encodingRaw = protocol.substring(WS_PROTOCOL_ENCODING.length);
          } else if (protocol.startsWith(WS_PROTOCOL_CONN_PARAMS)) {
            connParamsRaw = decodeURIComponent(
              protocol.substring(WS_PROTOCOL_CONN_PARAMS.length)
            );
          } else if (protocol.startsWith(WS_PROTOCOL_CONN_ID)) {
            connIdRaw = protocol.substring(WS_PROTOCOL_CONN_ID.length);
          } else if (protocol.startsWith(WS_PROTOCOL_CONN_TOKEN)) {
            connTokenRaw = protocol.substring(WS_PROTOCOL_CONN_TOKEN.length);
          }
        }
      }
      // Report bad input with the same typed errors as the HTTP routes
      // (getRequestEncoding / getRequestConnParams) instead of raw ZodError/SyntaxError.
      const encodingResult = EncodingSchema.safeParse(encodingRaw);
      if (!encodingResult.success) {
        throw new InvalidEncoding(encodingRaw ?? "undefined");
      }
      const encoding = encodingResult.data;
      let connParams;
      if (connParamsRaw) {
        try {
          connParams = JSON.parse(connParamsRaw);
        } catch (err) {
          throw new InvalidParams(
            `Invalid params JSON: ${stringifyError(err)}`
          );
        }
      }
      return await handleWebSocketConnect(
        c2.req.raw,
        runConfig,
        actorDriver,
        c2.env.actorId,
        encoding,
        connParams,
        connIdRaw,
        connTokenRaw
      );
    })(c, noopNext());
  });
  router.get("/connect/sse", async (c) => {
    return handleSseConnect(c, runConfig, actorDriver, c.env.actorId);
  });
  router.post("/action/:action", async (c) => {
    const actionName = c.req.param("action");
    return handleAction(c, runConfig, actorDriver, actionName, c.env.actorId);
  });
  router.post("/connections/message", async (c) => {
    const connId = c.req.header(HEADER_CONN_ID);
    const connToken = c.req.header(HEADER_CONN_TOKEN);
    if (!connId || !connToken) {
      throw new Error("Missing required parameters");
    }
    return handleConnectionMessage(
      c,
      runConfig,
      actorDriver,
      connId,
      connToken,
      c.env.actorId
    );
  });
  router.post("/connections/close", async (c) => {
    const connId = c.req.header(HEADER_CONN_ID);
    const connToken = c.req.header(HEADER_CONN_TOKEN);
    if (!connId || !connToken) {
      throw new Error("Missing required parameters");
    }
    return handleConnectionClose(
      c,
      runConfig,
      actorDriver,
      connId,
      connToken,
      c.env.actorId
    );
  });
  router.all("/raw/http/*", async (c) => {
    const actor2 = await actorDriver.loadActor(c.env.actorId);
    // Strip the /raw/http prefix so the actor's fetch handler sees its own paths.
    const url = new URL(c.req.url);
    const originalPath = url.pathname.replace(/^\/raw\/http/, "") || "/";
    const correctedUrl = new URL(originalPath + url.search, url.origin);
    const correctedRequest = new Request(correctedUrl, {
      method: c.req.method,
      headers: c.req.raw.headers,
      body: c.req.raw.body,
      duplex: "half"
    });
    loggerWithoutContext().debug({
      msg: "rewriting http url",
      from: c.req.url,
      to: correctedRequest.url
    });
    const response = await actor2.handleFetch(correctedRequest, {});
    if (!response) {
      throw new InternalError("handleFetch returned void unexpectedly");
    }
    return response;
  });
  router.get(`${PATH_RAW_WEBSOCKET_PREFIX}*`, async (c) => {
    const upgradeWebSocket = runConfig.getUpgradeWebSocket?.();
    if (!upgradeWebSocket) {
      return c.text(
        "WebSockets are not enabled for this driver. Use SSE instead.",
        400
      );
    }
    return upgradeWebSocket(async (c2) => {
      const url = new URL(c2.req.url);
      const pathWithQuery = c2.req.path + url.search;
      loggerWithoutContext().debug({
        msg: "actor router raw websocket",
        path: c2.req.path,
        url: c2.req.url,
        search: url.search,
        pathWithQuery
      });
      return await handleRawWebSocketHandler(
        c2.req.raw,
        pathWithQuery,
        actorDriver,
        c2.env.actorId
      );
    })(c, noopNext());
  });
  if (isInspectorEnabled(runConfig, "actor")) {
    router.route(
      "/inspect",
      new Hono().use(
        cors(runConfig.inspector.cors),
        secureInspector(runConfig),
        async (c, next) => {
          const inspector = (await actorDriver.loadActor(c.env.actorId)).inspector;
          invariant2(inspector, "inspector not supported on this platform");
          c.set("inspector", inspector);
          return next();
        }
      ).route("/", createActorInspectorRouter())
    );
  }
  router.notFound(handleRouteNotFound);
  router.onError(handleRouteError);
  return router;
}
// src/actor/mod.ts
/**
 * Validates a user-supplied actor definition with ActorConfigSchema (applying
 * its defaults) and wraps the result in an ActorDefinition.
 *
 * @throws whatever `ActorConfigSchema.parse` throws for an invalid definition.
 */
function actor(input) {
  return new ActorDefinition(ActorConfigSchema.parse(input));
}
// src/common/inline-websocket-adapter2.ts
import { WSContext } from "hono/ws";
/**
 * Returns the logger used by InlineWebSocketAdapter2, registered under "fake-event-source2".
 */
function logger3() {
  const loggerName = "fake-event-source2";
  return getLogger(loggerName);
}
/**
 * In-process, browser-style WebSocket that talks directly to a server-side
 * handler object (`onOpen` / `onMessage` / `onClose`) through a Hono `WSContext`.
 *
 * - `send` passes data straight to `handler.onMessage`.
 * - Data the handler sends through the WSContext surfaces as "message" events.
 * - Events fired while no `addEventListener` listener exists for their type are
 *   buffered and replayed to the first listener added for that type.
 *   `on<type>` handler properties still receive each event exactly once.
 */
var InlineWebSocketAdapter2 = class {
  // WebSocket readyState values
  CONNECTING = 0;
  OPEN = 1;
  CLOSING = 2;
  CLOSED = 3;
  // Private properties
  #handler;
  #wsContext;
  // Start in CONNECTING state
  #readyState = 0;
  // Messages the handler sent while still CONNECTING; replayed once the socket is OPEN.
  #queuedMessages = [];
  // Event buffering is needed since events can be fired
  // before JavaScript has a chance to add event listeners (e.g. within the same tick).
  // Entries are { type, event, handlerInvoked }, where handlerInvoked records whether
  // the matching on<type> property already received the event.
  #bufferedEvents = [];
  // Event listeners with buffering
  #eventListeners = /* @__PURE__ */ new Map();
  constructor(handler) {
    this.#handler = handler;
    this.#wsContext = new WSContext({
      raw: this,
      send: (data) => {
        logger3().debug({ msg: "WSContext.send called" });
        this.#handleMessage(data);
      },
      close: (code, reason) => {
        logger3().debug({ msg: "WSContext.close called", code, reason });
        this.#handleClose(code || 1e3, reason || "");
      },
      // Set readyState to 1 (OPEN) since handlers expect an open connection
      readyState: 1
    });
    this.#initialize();
  }
  get readyState() {
    return this.#readyState;
  }
  get binaryType() {
    return "arraybuffer";
  }
  // binaryType is fixed to "arraybuffer"; assignments are ignored.
  set binaryType(value) {
  }
  get bufferedAmount() {
    return 0;
  }
  get extensions() {
    return "";
  }
  get protocol() {
    return "";
  }
  get url() {
    return "";
  }
  /**
   * Passes `data` to the handler's onMessage. When not OPEN, it emits an "error"
   * event instead of throwing.
   */
  send(data) {
    logger3().debug({ msg: "send called", readyState: this.readyState });
    if (this.readyState !== this.OPEN) {
      const error = new Error("WebSocket is not open");
      logger3().warn({
        msg: "cannot send message, websocket not open",
        readyState: this.readyState,
        dataType: typeof data,
        dataLength: typeof data === "string" ? data.length : "binary",
        error
      });
      this.#fireError(error);
      return;
    }
    this.#handler.onMessage({ data }, this.#wsContext);
  }
  /**
   * Closes the connection
   */
  close(code = 1e3, reason = "") {
    if (this.readyState === this.CLOSED || this.readyState === this.CLOSING) {
      return;
    }
    logger3().debug({ msg: "closing fake websocket", code, reason });
    // While CLOSING, #handleClose ignores the handler's own WSContext.close(...)
    // so that exactly one close event (the one below) is emitted.
    this.#readyState = this.CLOSING;
    try {
      this.#handler.onClose({ code, reason, wasClean: true }, this.#wsContext);
    } catch (err) {
      logger3().error({ msg: "error closing websocket", error: err });
    } finally {
      this.#readyState = this.CLOSED;
      const closeEvent = {
        type: "close",
        wasClean: code === 1e3,
        code,
        reason,
        target: this,
        currentTarget: this
      };
      this.#fireClose(closeEvent);
    }
  }
  /**
   * Initialize the connection with the handler
   */
  async #initialize() {
    try {
      logger3().debug({ msg: "fake websocket initializing" });
      logger3().debug({ msg: "calling handler.onOpen with WSContext" });
      this.#handler.onOpen(void 0, this.#wsContext);
      this.#readyState = this.OPEN;
      logger3().debug({ msg: "fake websocket initialized and now OPEN" });
      this.#fireOpen();
      if (this.#queuedMessages.length > 0) {
        if (this.readyState !== this.OPEN) {
          logger3().warn({
            msg: "socket no longer open, dropping queued messages"
          });
          return;
        }
        logger3().debug({
          msg: `now processing ${this.#queuedMessages.length} queued messages`
        });
        const messagesToProcess = [...this.#queuedMessages];
        this.#queuedMessages = [];
        for (const message of messagesToProcess) {
          logger3().debug({ msg: "processing queued message" });
          this.#handleMessage(message);
        }
      }
    } catch (err) {
      logger3().error({
        msg: "error opening fake websocket",
        error: err,
        errorMessage: err instanceof Error ? err.message : String(err),
        stack: err instanceof Error ? err.stack : void 0
      });
      this.#fireError(err);
      this.close(1011, "Internal error during initialization");
    }
  }
  // Length of a message payload, for logging only.
  #dataLength(data) {
    if (typeof data === "string") return data.length;
    if (data instanceof ArrayBuffer || data instanceof Uint8Array) return data.byteLength;
    return "unknown";
  }
  /**
   * Handle messages received from the server via the WSContext
   */
  #handleMessage(data) {
    if (this.readyState === this.CONNECTING) {
      logger3().debug({
        msg: "message received before socket is OPEN, queuing",
        readyState: this.readyState,
        dataType: typeof data,
        dataLength: this.#dataLength(data)
      });
      this.#queuedMessages.push(data);
      return;
    }
    if (this.readyState !== this.OPEN) {
      // The queue is only drained during initialization, so queuing now would just leak.
      logger3().debug({
        msg: "message received after socket closed, dropping",
        readyState: this.readyState,
        dataType: typeof data
      });
      return;
    }
    logger3().debug({
      msg: "fake websocket received message from server",
      dataType: typeof data,
      dataLength: this.#dataLength(data)
    });
    const event = {
      type: "message",
      data,
      target: this,
      currentTarget: this
    };
    this.#dispatchEvent("message", event);
  }
  // Server-initiated close through the WSContext.
  #handleClose(code, reason) {
    // When CLOSED there is nothing to do. When CLOSING, close() is running and its
    // finally block emits the close event, so emitting here too would duplicate it.
    if (this.readyState === this.CLOSED || this.readyState === this.CLOSING) return;
    this.#readyState = this.CLOSED;
    const event = {
      type: "close",
      code,
      reason,
      wasClean: code === 1e3,
      target: this,
      currentTarget: this
    };
    this.#dispatchEvent("close", event);
  }
  addEventListener(type, listener) {
    if (!this.#eventListeners.has(type)) {
      this.#eventListeners.set(type, []);
    }
    this.#eventListeners.get(type).push(listener);
    this.#flushBufferedEvents(type);
  }
  removeEventListener(type, listener) {
    const listeners = this.#eventListeners.get(type);
    if (listeners) {
      const index = listeners.indexOf(listener);
      if (index !== -1) {
        listeners.splice(index, 1);
      }
    }
  }
  // The on<type> handler property for a standard event type, or null.
  #handlerProperty(type) {
    switch (type) {
      case "open":
        return this.#onopen;
      case "close":
        return this.#onclose;
      case "error":
        return this.#onerror;
      case "message":
        return this.#onmessage;
      default:
        return null;
    }
  }
  /**
   * Deliver `event` to the registered listeners (or buffer it if there are none),
   * then to the on<type> property unless `handlerAlreadyInvoked` is set.
   */
  #dispatchEvent(type, event, handlerAlreadyInvoked = false) {
    const listeners = this.#eventListeners.get(type);
    const handler = this.#handlerProperty(type);
    if (listeners && listeners.length > 0) {
      logger3().debug(
        `dispatching ${type} event to ${listeners.length} listeners`
      );
      // Iterate a snapshot so listeners that remove themselves do not skip others.
      for (const listener of [...listeners]) {
        try {
          listener(event);
        } catch (err) {
          logger3().error({
            msg: `error in ${type} event listener`,
            error: err
          });
        }
      }
    } else {
      logger3().debug({
        msg: `no ${type} listeners registered, buffering event`
      });
      this.#bufferedEvents.push({
        type,
        event,
        handlerInvoked: handlerAlreadyInvoked || Boolean(handler)
      });
    }
    if (handler && !handlerAlreadyInvoked) {
      try {
        handler.call(this, event);
      } catch (error) {
        logger3().error({ msg: `error in on${type} handler`, error });
      }
    }
  }
  dispatchEvent(event) {
    this.#dispatchEvent(event.type, event);
    return true;
  }
  // Replay buffered events of `type` now that a listener exists, without
  // re-invoking on<type> properties that already received them.
  #flushBufferedEvents(type) {
    const eventsToFlush = this.#bufferedEvents.filter(
      (buffered) => buffered.type === type
    );
    this.#bufferedEvents = this.#bufferedEvents.filter(
      (buffered) => buffered.type !== type
    );
    for (const { event, handlerInvoked } of eventsToFlush) {
      this.#dispatchEvent(type, event, handlerInvoked);
    }
  }
  #fireOpen() {
    try {
      const event = {
        type: "open",
        target: this,
        currentTarget: this
      };
      this.#dispatchEvent("open", event);
    } catch (err) {
      logger3().error({ msg: "error in open event", error: err });
    }
  }
  #fireClose(event) {
    try {
      this.#dispatchEvent("close", event);
    } catch (err) {
      logger3().error({ msg: "error in close event", error: err });
    }
  }
  #fireError(error) {
    try {
      const event = {
        type: "error",
        target: this,
        currentTarget: this,
        error,
        message: error instanceof Error ? error.message : String(error)
      };
      this.#dispatchEvent("error", event);
    } catch (err) {
      logger3().error({ msg: "error in error event", error: err });
    }
    logger3().error({ msg: "websocket error", error });
  }
  // Event handler properties with getters/setters
  #onopen = null;
  #onclose = null;
  #onerror = null;
  #onmessage = null;
  get onopen() {
    return this.#onopen;
  }
  set onopen(handler) {
    this.#onopen = handler;
  }
  get onclose() {
    return this.#onclose;
  }
  set onclose(handler) {
    this.#onclose = handler;
  }
  get onerror() {
    return this.#onerror;
  }
  set onerror(handler) {
    this.#onerror = handler;
  }
  get onmessage() {
    return this.#onmessage;
  }
  set onmessage(handler) {
    this.#onmessage = handler;
  }
};
// src/drivers/engine/actor-driver.ts
import { Runner } from "@rivetkit/engine-runner";
import * as cbor3 from "cbor-x";
import { streamSSE as streamSSE2 } from "hono/streaming";
import { WSContext as WSContext2 } from "hono/ws";
import invariant3 from "invariant";
// src/drivers/engine/kv.ts
// Keys used with the runner's KV store for actor data.
var KEYS = {
  // Single-byte key under which an actor's serialized persist data is kept.
  PERSIST_DATA: new Uint8Array([1])
};
// src/drivers/engine/log.ts
/** Returns the logger registered under the "driver-engine" name. */
function logger4() {
  const loggerName = "driver-engine";
  return getLogger(loggerName);
}
// src/drivers/engine/actor-driver.ts
var EngineActorDriver = class {
#registryConfig;
#runConfig;
#managerDriver;
#inlineClient;
#runner;
#actors = /* @__PURE__ */ new Map();
#actorRouter;
#version = 1;
// Version for the runner protocol
#alarmTimeout;
#runnerStarted = Promise.withResolvers();
#runnerStopped = Promise.withResolvers();
constructor(registryConfig, runConfig, managerDriver, inlineClient) {
this.#registryConfig = registryConfig;
this.#runConfig = runConfig;
this.#managerDriver = managerDriver;
this.#inlineClient = inlineClient;
const token = runConfig.token ?? runConfig.token;
if (token && runConfig.inspector && runConfig.inspector.enabled) {
runConfig.inspector.token = () => token;
}
this.#actorRouter = createActorRouter(
runConfig,
this,
registryConfig.test.enabled
);
let hasDisconnected = false;
const engineRunnerConfig = {
version: this.#version,
endpoint: getEndpoint(runConfig),
token,
namespace: runConfig.namespace ?? runConfig.namespace,
totalSlots: runConfig.totalSlots ?? runConfig.totalSlots,
runnerName: runConfig.runnerName ?? runConfig.runnerName,
runnerKey: runConfig.runnerKey,
metadata: {
inspectorToken: this.#runConfig.inspector.token()
},
prepopulateActorNames: Object.fromEntries(
Object.keys(this.#registryConfig.use).map((name) => [
name,
{ metadata: {} }
])
),
onConnected: () => {
if (hasDisconnected) {
logger4().info({
msg: "runner reconnected",
namespace: this.#runConfig.namespace,
runnerName: this.#runConfig.runnerName
});
} else {
logger4().debug({
msg: "runner connected",
namespace: this.#runConfig.namespace,
runnerName: this.#runConfig.runnerName
});
}
this.#runnerStarted.resolve(void 0);
},
onDisconnected: () => {
logger4().warn({
msg: "runner disconnected",
namespace: this.#runConfig.namespace,
runnerName: this.#runConfig.runnerName
});
hasDisconnected = true;
},
onShutdown: () => {
this.#runnerStopped.resolve(void 0);
},
fetch: this.#runnerFetch.bind(this),
websocket: this.#runnerWebSocket.bind(this),
onActorStart: this.#runnerOnActorStart.bind(this),
onActorStop: this.#runnerOnActorStop.bind(this),
logger: getLogger("engine-runner")
};
this.#runner = new Runner(engineRunnerConfig);
this.#runner.start();
logger4().debug({
msg: "engine runner started",
endpoint: runConfig.endpoint,
namespace: runConfig.namespace,
runnerName: runConfig.runnerName
});
}
async #loadActorHandler(actorId) {
const handler = this.#actors.get(actorId);
if (!handler) throw new Error(`Actor handler does not exist ${actorId}`);
if (handler.actorStartPromise) await handler.actorStartPromise.promise;
if (!handler.actor) throw new Error("Actor should be loaded");
return handler;
}
async loadActor(actorId) {
const handler = await this.#loadActorHandler(actorId);
if (!handler.actor) throw new Error(`Actor ${actorId} failed to load`);
return handler.actor;
}
getContext(actorId) {
return {};
}
async readPersistedData(actorId) {
const handler = this.#actors.get(actorId);
if (!handler) throw new Error(`Actor ${actorId} not loaded`);
if (handler.persistedData) return handler.persistedData;
const [value] = await this.#runner.kvGet(actorId, [KEYS.PERSIST_DATA]);
if (value !== null) {
handler.persistedData = value;
return value;
} else {
return void 0;
}
}
async writePersistedData(actorId, data) {
const handler = this.#actors.get(actorId);
if (!handler) throw new Error(`Actor ${actorId} not loaded`);
handler.persistedData = data;
await this.#runner.kvPut(actorId, [[KEYS.PERSIST_DATA, data]]);
}
async setAlarm(actor2, timestamp) {
if (this.#alarmTimeout) {
this.#alarmTimeout.abort();
this.#alarmTimeout = void 0;
}
const delay = Math.max(0, timestamp - Date.now());
this.#alarmTimeout = setLongTimeout(() => {
actor2._onAlarm();
this.#alarmTimeout = void 0;
}, delay);
this.#runner.setAlarm(actor2.id, timestamp);
}
async getDatabase(_actorId) {
return void 0;
}
// Runner lifecycle callbacks
async #runnerOnActorStart(actorId, generation, runConfig) {
var _a;
logger4().debug({
msg: "runner actor starting",
actorId,
name: runConfig.name,
key: runConfig.key,
generation
});
let input;
if (runConfig.input) {
input = cbor3.decode(runConfig.input);
}
let handler = this.#actors.get(actorId);
if (!handler) {
handler = {
actorStartPromise: promiseWithResolvers(),
persistedData: serializeEmptyPersistData(input)
};
this.#actors.set(actorId, handler);
}
const name = runConfig.name;
invariant3(runConfig.key, "actor should have a key");
const key = deserializeActorKey(runConfig.key);
const definition = lookupInRegistry(this.#registryConfig, runConfig.name);
handler.actor = definition.instantiate();
await handler.actor.start(
this,
this.#inlineClient,
actorId,
name,
key,
"unknown"
// TODO: Add regions
);
(_a = handler.actorStartPromise) == null ? void 0 : _a.resolve();
handler.actorStartPromise = void 0;
logger4().debug({ msg: "runner actor started", actorId, name, key });
}
async #runnerOnActorStop(actorId, generation) {
logger4().debug({ msg: "runner actor stopping", actorId, generation });
const handler = this.#actors.get(actorId);
if (handler == null ? void 0 : handler.actor) {
await handler.actor._stop();
this.#actors.delete(actorId);
}
logger4().debug({ msg: "runner actor stopped", actorId });
}
async #runnerFetch(actorId, request) {
logger4().debug({
msg: "runner fetch",
actorId,
url: request.url,
method: request.method
});
return await this.#actorRouter.fetch(request, { actorId });
}
async #runnerWebSocket(actorId, websocketRaw, request) {
const websocket = websocketRaw;
logger4().debug({ msg: "runner websocket", actorId, url: request.url });
const url = new URL(request.url);
const protocols = request.headers.get("sec-websocket-protocol");
if (protocols === null)
throw new Error(`Missing sec-websocket-protocol header`);
let encodingRaw;
let connParamsRaw;
if (protocols) {
const protocolList = protocols.split(",").map((p) => p.trim());
for (const protocol of protocolList) {
if (protocol.startsWith(WS_PROTOCOL_ENCODING)) {
encodingRaw = protocol.substring(WS_PROTOCOL_ENCODING.length);
} else if (protocol.startsWith(WS_PROTOCOL_CONN_PARAMS)) {
connParamsRaw = decodeURIComponent(
protocol.substring(WS_PROTOCOL_CONN_PARAMS.length)
);
}
}
}
const encoding = EncodingSchema.parse(encodingRaw);
const connParams = connParamsRaw ? JSON.parse(connParamsRaw) : void 0;
let wsHandlerPromise;
if (url.pathname === PATH_CONNECT_WEBSOCKET) {
wsHandlerPromise = handleWebSocketConnect(
request,
this.#runConfig,
this,
actorId,
encoding,
connParams,
// Extract connId and connToken from protocols if needed
void 0,
void 0
);
} else if (url.pathname.startsWith(PATH_RAW_WEBSOCKET_PREFIX)) {
wsHandlerPromise = handleRawWebSocketHandler(
request,
url.pathname + url.search,
this,
actorId
);
} else {
throw new Error(`Unreachable path: ${url.pathname}`);
}
const wsContext = new WSContext2(websocket);
wsHandlerPromise.catch((err) => {
logger4().error({ msg: "building websocket handlers errored", err });
wsContext.close(1011, `${err}`);
});
if (websocket.readyState === 1) {
wsHandlerPromise.then((x) => {
var _a;
return (_a = x.onOpen) == null ? void 0 : _a.call(x, new Event("open"), wsContext);
});
} else {
websocket.addEventListener("open", (event) => {
wsHandlerPromise.then((x) => {
var _a;
return (_a = x.onOpen) == null ? void 0 : _a.call(x, event, wsContext);
});
});
}
websocket.addEventListener("message", (event) => {
wsHandlerPromise.then((x) => {
var _a;
return (_a = x.onMessage) == null ? void 0 : _a.call(x, event, wsContext);
});
});
websocket.addEventListener("close", (event) => {
wsHandlerPromise.then((x) => {
var _a;
return (_a = x.onClose) == null ? void 0 : _a.call(x, event, wsContext);
});
});
websocket.addEventListener("error", (event) => {
wsHandlerPromise.then((x) => {
var _a;
return (_a = x.onError) == null ? void 0 : _a.call(x, event, wsContext);
});
});
}
async sleep(actorId) {
this.#runner.sleep