@rivetkit/redis
Version:
_Lightweight Libraries for Backends_
1,598 lines (1,576 loc) • 47 kB
JavaScript
import {
ActorPeer,
RedisActorDriver,
logger
} from "./chunk-W4M5YHOW.js";
import {
RedisManagerDriver
} from "./chunk-4ITC5TMQ.js";
import {
InternalError,
KEYS,
PUBSUB,
assertUnreachable
} from "./chunk-K6L53HR4.js";
// src/mod.ts
import {
createActorRouter,
createClientWithDriver,
createInlineClientDriver
} from "@rivetkit/core";
// src/config.ts
import { Redis } from "ioredis";
import { z as z2 } from "zod";
// src/coordinate/config.ts
import { z } from "zod";
/**
 * Zod schema for the options shared by coordinate drivers.
 *
 * `actorPeer` tunes actor leadership leases and leader messaging. Every field
 * inside it has a default, so `actorPeer: {}` parses to the values below.
 * Units are presumably milliseconds (`messageAckTimeout` is used as a
 * `setTimeout` delay in this file) — confirm for the lease-check fields.
 */
var CoordinateDriverConfig = z.object({
  actorPeer: z.object({
    leaseDuration: z.number().default(3000),
    renewLeaseGrace: z.number().default(1500),
    checkLeaseInterval: z.number().default(1000),
    checkLeaseJitter: z.number().default(500),
    messageAckTimeout: z.number().default(1000),
  }),
});
// src/config.ts
/**
 * Full Redis driver configuration: the coordinate options plus the Redis
 * client and the prefix used when building Redis keys and pub/sub channels.
 */
var RedisDriverConfig = CoordinateDriverConfig.extend({
  // When omitted, a client is built from REDIS_HOST / REDIS_PORT / REDIS_PASSWORD.
  redis: z2
    .custom((value) => value instanceof Redis, {
      message: "Must be an instance of Redis",
    })
    .optional()
    .default(() => {
      const { REDIS_HOST, REDIS_PORT, REDIS_PASSWORD } = process.env;
      return new Redis({
        host: REDIS_HOST ?? "localhost",
        port: REDIS_PORT ? parseInt(REDIS_PORT, 10) : 6379,
        password: REDIS_PASSWORD,
      });
    }),
  keyPrefix: z2
    .string()
    .default(() => process.env.REDIS_KEY_PREFIX ?? "rivetkit"),
});
// src/coordinate.ts
import * as cbor from "cbor-x";
import dedent from "dedent";
/**
 * Redis-backed coordinate driver.
 *
 * - Node-to-node messaging: each node subscribes to its own pub/sub channel
 *   (`PUBSUB.node(keyPrefix, nodeId)`); payloads are CBOR-encoded.
 * - Actor leadership: one lease key per actor (`KEYS.ACTOR.LEASE.node`) that
 *   holds the leader's node id with a `PX` (millisecond) expiry. It is only
 *   manipulated through the Lua scripts registered in `#defineRedisScripts`,
 *   so each check-and-set runs atomically on the Redis server.
 */
var RedisCoordinateDriver = class {
  #driverConfig;
  #redis;
  #nodeSub;

  /**
   * @param driverConfig - Driver config; `keyPrefix` namespaces all keys/channels.
   * @param redis - ioredis client used for commands and duplicated for the
   *   subscriber connection. The lease Lua commands are registered on it here.
   */
  constructor(driverConfig, redis) {
    this.#driverConfig = driverConfig;
    this.#redis = redis;
    this.#defineRedisScripts();
  }

  /**
   * Subscribes to this node's channel and calls `callback` with every decoded
   * message.
   *
   * A duplicated connection is used because a Redis connection in subscriber
   * mode cannot issue regular commands. Undecodable payloads and failures
   * raised by `callback` (sync throw or rejected promise) are logged rather
   * than thrown out of the ioredis event handler or left unhandled.
   */
  async createNodeSubscriber(selfNodeId, callback) {
    const logHandlerFailure = (error) => {
      logger().error("error handling node message", { error: `${error}` });
    };
    this.#nodeSub = this.#redis.duplicate();
    this.#nodeSub.on("messageBuffer", (_channel, messageRaw) => {
      let message;
      try {
        message = cbor.decode(messageRaw);
      } catch (error) {
        logger().error("failed to decode node message", { error: `${error}` });
        return;
      }
      try {
        Promise.resolve(callback(message)).catch(logHandlerFailure);
      } catch (error) {
        logHandlerFailure(error);
      }
    });
    await this.#nodeSub.subscribe(
      PUBSUB.node(this.#driverConfig.keyPrefix, selfNodeId)
    );
  }

  /** CBOR-encodes `message` and publishes it on `targetNodeId`'s channel. */
  async publishToNode(targetNodeId, message) {
    await this.#redis.publish(
      PUBSUB.node(this.#driverConfig.keyPrefix, targetNodeId),
      cbor.encode(message)
    );
  }

  /**
   * Reads the actor's metadata and lease holder in a single MGET.
   *
   * @returns `{ actor: undefined }` when the actor has no metadata; otherwise
   *   `{ actor: { leaderNodeId } }`, where `leaderNodeId` is undefined if no
   *   node currently holds the lease.
   */
  async getActorLeader(actorId) {
    const { keyPrefix } = this.#driverConfig;
    const [metadata, nodeId] = await this.#redis.mget([
      // TODO: Use exists in pipeline instead of getting all data
      KEYS.ACTOR.metadata(keyPrefix, actorId),
      KEYS.ACTOR.LEASE.node(keyPrefix, actorId)
    ]);
    if (!metadata) {
      return { actor: void 0 };
    }
    return {
      actor: {
        leaderNodeId: nodeId || void 0
      }
    };
  }

  /**
   * In one MULTI/EXEC transaction, reads the actor's metadata and tries to take
   * its lease.
   *
   * NOTE(review): the lease script runs even when the metadata is missing, so
   * a lease key can be created for a nonexistent actor until it expires.
   *
   * @returns `{ actor: undefined }` when the actor has no metadata; otherwise
   *   the decoded `name`/`key` plus the node id holding the lease afterwards
   *   (`selfNodeId` if it won, or the existing holder).
   * @throws Error if the transaction yields no results or either command errors.
   */
  async startActorAndAcquireLease(actorId, selfNodeId, leaseDuration) {
    const { keyPrefix } = this.#driverConfig;
    const execRes = await this.#redis
      .multi()
      .getBuffer(KEYS.ACTOR.metadata(keyPrefix, actorId))
      .actorPeerAcquireLease(
        KEYS.ACTOR.LEASE.node(keyPrefix, actorId),
        selfNodeId,
        leaseDuration
      )
      .exec();
    if (!execRes) {
      throw new Error("Redis transaction failed");
    }
    const [[getErr, metadataRaw], [leaseErr, leaderNodeId]] = execRes;
    if (getErr) throw new Error(`Redis GET error: ${getErr}`);
    if (leaseErr) throw new Error(`Redis acquire lease error: ${leaseErr}`);
    // No metadata means the actor was never initialized: nothing to lead.
    if (!metadataRaw) {
      return { actor: void 0 };
    }
    const metadata = cbor.decode(metadataRaw);
    return {
      actor: {
        name: metadata.name,
        key: metadata.key,
        leaderNodeId
      }
    };
  }

  /**
   * Refreshes the lease expiry if `selfNodeId` still holds it.
   *
   * @returns `{ leaseValid: false }` when another node (or nobody) holds it.
   */
  async extendLease(actorId, selfNodeId, leaseDuration) {
    const res = await this.#redis.actorPeerExtendLease(
      KEYS.ACTOR.LEASE.node(this.#driverConfig.keyPrefix, actorId),
      selfNodeId,
      leaseDuration
    );
    return {
      leaseValid: res === 1
    };
  }

  /**
   * Takes the lease if it is free.
   *
   * @returns `{ newLeaderNodeId }`, the node holding the lease afterwards.
   */
  async attemptAcquireLease(actorId, selfNodeId, leaseDuration) {
    const newLeaderNodeId = await this.#redis.actorPeerAcquireLease(
      KEYS.ACTOR.LEASE.node(this.#driverConfig.keyPrefix, actorId),
      selfNodeId,
      leaseDuration
    );
    return {
      newLeaderNodeId
    };
  }

  /** Deletes the lease only if `nodeId` is its current holder. */
  async releaseLease(actorId, nodeId) {
    await this.#redis.actorPeerReleaseLease(
      KEYS.ACTOR.LEASE.node(this.#driverConfig.keyPrefix, actorId),
      nodeId
    );
  }

  /**
   * Registers the lease Lua scripts as custom ioredis commands.
   * KEYS[1] is the lease key, ARGV[1] the node id, ARGV[2] the PX duration.
   */
  #defineRedisScripts() {
    this.#redis.defineCommand("actorPeerAcquireLease", {
      numberOfKeys: 1,
      lua: dedent`
        -- Get the current value of the key
        local currentValue = redis.call("get", KEYS[1])
        -- Return the current value if an entry already exists
        if currentValue then
        return currentValue
        end
        -- Create an entry for the provided key
        redis.call("set", KEYS[1], ARGV[1], "PX", ARGV[2])
        -- Return the value to indicate the entry was added
        return ARGV[1]
      `
    });
    this.#redis.defineCommand("actorPeerExtendLease", {
      numberOfKeys: 1,
      lua: dedent`
        -- Return 0 if an entry exists with a different lease holder
        if redis.call("get", KEYS[1]) ~= ARGV[1] then
        return 0
        end
        -- Update the entry for the provided key
        redis.call("set", KEYS[1], ARGV[1], "PX", ARGV[2])
        -- Return 1 to indicate the entry was updated
        return 1
      `
    });
    this.#redis.defineCommand("actorPeerReleaseLease", {
      numberOfKeys: 1,
      lua: dedent`
        -- Only remove the entry for this lock value
        if redis.call("get", KEYS[1]) == ARGV[1] then
        redis.pcall("del", KEYS[1])
        return 1
        end
        -- Return 0 if no entry was removed.
        return 0
      `
    });
  }
};
// src/coordinate/node/mod.ts
import invariant2 from "invariant";
// src/coordinate/relay-conn.ts
import {
generateConnId,
generateConnToken
} from "@rivetkit/core";
// src/coordinate/node/message.ts
import pRetry, { AbortError } from "p-retry";
async function publishMessageToLeader(registryConfig, driverConfig, CoordinateDriver, globalState, actorId, message, signal) {
message.n = globalState.nodeId;
const messageId = crypto.randomUUID();
message.m = messageId;
await pRetry(
() => publishMessageToLeaderInner(
registryConfig,
driverConfig,
CoordinateDriver,
globalState,
actorId,
messageId,
message,
signal
),
{
signal,
minTimeout: 1e3,
retries: 5,
onFailedAttempt: (error) => {
logger().warn("error publishing message", {
attempt: error.attemptNumber,
error: error.message
});
}
}
);
}
async function publishMessageToLeaderNoRetry(registryConfig, driverConfig, CoordinateDriver, globalState, actorId, message, signal) {
message.n = globalState.nodeId;
const messageId = crypto.randomUUID();
message.m = messageId;
try {
await publishMessageToLeaderInner(
registryConfig,
driverConfig,
CoordinateDriver,
globalState,
actorId,
messageId,
message,
signal
);
} catch (error) {
if (error instanceof Error) {
if (error.message === "Actor not initialized") {
throw new LeaderChangedError("Actor not found");
} else if (error.message === "actor not leased, may be transferring leadership") {
throw new LeaderChangedError("Leader is changing");
} else if (error.message === "Ack timed out") {
throw new LeaderChangedError("Leader not responding");
}
}
throw error;
}
}
/**
 * Raised by `publishMessageToLeaderNoRetry` when the actor's leader could not
 * be reached because it is missing, changing, or not acking.
 */
var LeaderChangedError = class extends Error {
  name = "LeaderChangedError";

  constructor(message) {
    super(message);
  }
};
/**
 * Sends one message to the actor's current leader node and waits for its ack.
 *
 * The ack arrives asynchronously: the receiving node replies with an `a`
 * message whose handler resolves `globalState.messageAckResolvers[messageId]`.
 *
 * @throws AbortError("Actor not initialized") when the actor has no metadata
 *   (an AbortError stops p-retry from retrying).
 * @throws Error("actor not leased, may be transferring leadership") when no
 *   node currently holds the lease.
 * @throws Error("Ack timed out") when no ack arrives within
 *   `driverConfig.actorPeer.messageAckTimeout` ms.
 * @throws AbortError("Aborted") when `signal` is or becomes aborted.
 */
async function publishMessageToLeaderInner(registryConfig, driverConfig, CoordinateDriver, globalState, actorId, messageId, message, signal) {
  const { actor } = await CoordinateDriver.getActorLeader(actorId);
  if (!actor) throw new AbortError("Actor not initialized");
  if (!actor.leaderNodeId) {
    throw new Error("actor not leased, may be transferring leadership");
  }
  logger().debug("found actor leader node", { nodeId: actor.leaderNodeId });

  // An "abort" listener registered after the signal fired never runs, which
  // would leave us waiting for the full ack timeout; fail fast instead.
  if (signal?.aborted) throw new AbortError("Aborted");

  let ackResolve;
  let ackReject;
  const ackPromise = new Promise((resolve, reject) => {
    ackResolve = resolve;
    ackReject = reject;
  });
  // If the promise is rejected (abort) while publishToNode is still pending and
  // publishing then throws, nobody awaits it. Mark it handled so it is not
  // reported as an unhandled rejection; `await ackPromise` still sees it.
  ackPromise.catch(() => {});

  globalState.messageAckResolvers.set(messageId, ackResolve);
  const signalListener = () => ackReject(new AbortError("Aborted"));
  signal?.addEventListener("abort", signalListener);
  const timeoutId = setTimeout(
    () => ackReject(new Error("Ack timed out")),
    driverConfig.actorPeer.messageAckTimeout
  );
  try {
    await CoordinateDriver.publishToNode(actor.leaderNodeId, message);
    logger().debug("waiting for message ack", { messageId });
    await ackPromise;
    logger().debug("received message ack", { messageId });
  } finally {
    globalState.messageAckResolvers.delete(messageId);
    signal?.removeEventListener("abort", signalListener);
    clearTimeout(timeoutId);
  }
}
// src/coordinate/relay-conn.ts
/**
 * Follower-side connection to an actor whose leader may live on another node.
 *
 * `start()` generates the connection id/token, acquires an `ActorPeer`
 * reference for the actor, and registers itself in `globalState.relayConns`.
 * `disconnect()` closes the client side through `driver`, unregisters, can
 * notify the leader, and releases the ActorPeer reference.
 */
var RelayConn = class {
  #registryConfig;
  #runConfig;
  #driverConfig;
  #coordinateDriver;
  #actorDriver;
  #inlineClient;
  #globalState;
  #driver;
  #actorId;
  #actorPeer;
  #connId;
  #connToken;
  #disposed = false;
  // Aborted on disconnect so in-flight leader publishes stop waiting for acks.
  #abortController = new AbortController();

  get actorId() {
    return this.#actorId;
  }

  /** @throws InternalError if read before `start()`. */
  get connId() {
    if (!this.#connId) throw new InternalError("Missing connId");
    return this.#connId;
  }

  /** @throws InternalError if read before `start()`. */
  get connToken() {
    if (!this.#connToken) throw new InternalError("Missing connToken");
    return this.#connToken;
  }

  /**
   * @param driver - Client-side hooks; `driver.disconnect(reason)` is awaited
   *   by `disconnect()`.
   */
  constructor(registryConfig, runConfig, driverConfig, actorDriver, inlineClient, coordinateDriver, globalState, driver, actorId) {
    this.#registryConfig = registryConfig;
    this.#runConfig = runConfig;
    this.#driverConfig = driverConfig;
    this.#coordinateDriver = coordinateDriver;
    this.#actorDriver = actorDriver;
    this.#inlineClient = inlineClient;
    this.#driver = driver;
    this.#globalState = globalState;
    this.#actorId = actorId;
  }

  /** Assigns ids, acquires the ActorPeer, and registers the connection. */
  async start() {
    const connId = generateConnId();
    const connToken = generateConnToken();
    this.#connId = connId;
    this.#connToken = connToken;
    logger().debug("starting relay connection", {
      actorId: this.#actorId,
      connId: this.#connId
    });
    this.#actorPeer = await ActorPeer.acquire(
      this.#registryConfig,
      this.#runConfig,
      this.#driverConfig,
      this.#actorDriver,
      this.#inlineClient,
      this.#coordinateDriver,
      this.#globalState,
      this.#actorId,
      connId
    );
    this.#globalState.relayConns.set(connId, this);
  }

  /**
   * Sends `message` to the actor's leader.
   *
   * @param retry - true: p-retry-backed delivery; false: a single attempt that
   *   surfaces leader loss as `LeaderChangedError`.
   * No-op (with a warning) once the connection has been disposed.
   */
  async publishMessageToleader(message, retry) {
    if (this.#disposed) {
      logger().warn(
        "attempted to call sendMessageToLeader on disposed RelayConn"
      );
      return;
    }
    const publish = retry ? publishMessageToLeader : publishMessageToLeaderNoRetry;
    await publish(
      this.#registryConfig,
      this.#driverConfig,
      this.#coordinateDriver,
      this.#globalState,
      this.#actorId,
      message,
      this.#abortController.signal
    );
  }

  /**
   * Closes the connection and cleans it up.
   *
   * @param fromLeader - If this message is coming from the leader. This will prevent sending a close message back to the leader.
   * @param reason - Passed to `driver.disconnect`.
   * @param disconnectMessageToleader - Optional message sent to the leader
   *   (best-effort) before the ActorPeer reference is released.
   */
  async disconnect(fromLeader, reason, disconnectMessageToleader) {
    if (this.#disposed) {
      logger().warn("attempted to call disconnect on disposed RelayConn");
      return;
    }
    this.#disposed = true;
    this.#abortController.abort();
    await this.#driver.disconnect(reason);
    const connId = this.#connId;
    if (!connId) {
      logger().warn("disposing connection without connection id");
      return;
    }
    this.#globalState.relayConns.delete(connId);
    if (!fromLeader && this.#actorPeer?.leaderNodeId && disconnectMessageToleader) {
      try {
        await publishMessageToLeader(
          this.#registryConfig,
          this.#driverConfig,
          this.#coordinateDriver,
          this.#globalState,
          this.#actorId,
          disconnectMessageToleader,
          void 0
        );
      } catch (error) {
        // The leader may be gone; notifying it is best-effort, but the local
        // ActorPeer reference below must always be released.
        logger().warn("failed to send disconnect message to leader", {
          actorId: this.#actorId,
          connId,
          error: `${error}`
        });
      }
    }
    await this.#actorPeer?.removeConnectionReference(connId);
  }
};
// src/coordinate/node/message-handlers/fetch.ts
/**
 * Leader-side handler for a follower's `lf` (fetch) message.
 *
 * Rebuilds the HTTP request and runs it through `actorRouter` for the leader
 * actor. The outcome goes back to `nodeId` as an `ffr` message tagged with the
 * request id `ri`: the actor's response, a 404 when this node has no leader
 * actor for `fetch.ai`, or a 500 with the error message on any failure.
 * Only logs when `nodeId` is missing.
 */
async function handleLeaderFetch(globalState, coordinateDriver, actorRouter, nodeId, fetch) {
  if (!nodeId) {
    logger().error("node id not provided for leader fetch");
    return;
  }
  const reply = (fields) => coordinateDriver.publishToNode(nodeId, {
    b: { ffr: { ri: fetch.ri, ...fields } }
  });
  try {
    const actor = await ActorPeer.getLeaderActor(globalState, fetch.ai);
    if (!actor) {
      await reply({ status: 404, headers: {}, error: "Actor not found" });
      return;
    }
    const url = new URL(`http://actor${fetch.url}`);
    let requestBody;
    if (fetch.body) {
      requestBody = fetch.body instanceof Uint8Array ? fetch.body : new TextEncoder().encode(fetch.body);
    }
    const request = new Request(url, {
      method: fetch.method,
      headers: fetch.headers,
      body: requestBody
    });
    const response = await actorRouter.fetch(request, { actorId: actor.id });
    if (!response) {
      throw new Error("handleFetch returned void unexpectedly");
    }
    // Length/framing headers are dropped, presumably because the body is
    // re-serialized when relayed — confirm.
    const droppedHeaders = new Set(["content-length", "transfer-encoding"]);
    const headers = {};
    response.headers.forEach((value, key) => {
      if (!droppedHeaders.has(key.toLowerCase())) {
        headers[key] = value;
      }
    });
    const body = response.body ? new Uint8Array(await response.arrayBuffer()) : void 0;
    await reply({ status: response.status, headers, body });
  } catch (error) {
    await reply({
      status: 500,
      headers: {},
      error: error instanceof Error ? error.message : "Internal server error"
    });
  }
}
/**
 * Hands a leader's fetch response (`ffr`) to the pending request waiting on
 * it, looked up by request id `ri`. Responses with no waiter are ignored.
 * The resolver entry is not removed here.
 */
function handleFollowerFetchResponse(globalState, response) {
  const pending = globalState.fetchResponseResolvers.get(response.ri);
  if (!pending) return;
  pending(response);
}
// src/coordinate/node/message-handlers/websocket-follower.ts
/**
 * Follower-side handler for the leader's `fwo` (websocket open) confirmation.
 *
 * A matching relay websocket adapter is transitioned to open via
 * `_handleOpen()`. A matching proxied follower socket only gets a debug log.
 * Unknown ids are logged as a warning together with all known ids.
 */
async function handleFollowerWebSocketOpen(globalState, open) {
  const { relayWebSockets, followerWebSockets } = globalState;
  const websocketId = open.wi;
  logger().debug("handling follower websocket open", {
    websocketId,
    hasRelayWebSockets: !!relayWebSockets,
    relayWebSocketsSize: relayWebSockets?.size ?? 0,
    hasFollowerWebSockets: !!followerWebSockets,
    followerWebSocketsSize: followerWebSockets?.size ?? 0
  });

  const relayWs = relayWebSockets?.get(websocketId);
  if (relayWs) {
    logger().debug("calling _handleOpen on relay websocket", { websocketId });
    relayWs._handleOpen();
    return;
  }

  if (followerWebSockets?.get(websocketId)) {
    logger().debug("follower websocket open confirmed by leader", { websocketId });
    return;
  }

  logger().warn("received websocket open for nonexistent follower websocket", {
    websocketId,
    allRelayWebSocketIds: [...(relayWebSockets?.keys() ?? [])],
    allFollowerWebSocketIds: [...(followerWebSockets?.keys() ?? [])]
  });
}
/**
 * Follower-side handler for a leader's `fwm` (websocket message).
 *
 * The message is delivered to the first socket registered under `wi`, checked
 * in this order: raw websockets, relay websocket adapters, proxied follower
 * sockets. Uint8Array payloads are sent to real sockets as an ArrayBuffer that
 * contains exactly the view's bytes. Unknown ids are logged.
 */
async function handleFollowerWebSocketMessage(globalState, message) {
  const { wi: websocketId, data, binary } = message;
  const toWirePayload = () => data instanceof Uint8Array ? data.buffer.slice(data.byteOffset, data.byteOffset + data.byteLength) : data;

  const rawWs = globalState.rawWebSockets.get(websocketId);
  if (rawWs) {
    rawWs.send(toWirePayload());
    return;
  }

  const relayWs = globalState.relayWebSockets?.get(websocketId);
  if (relayWs) {
    relayWs._handleMessage(data, binary);
    return;
  }

  const followerWs = globalState.followerWebSockets?.get(websocketId);
  if (!followerWs) {
    logger().warn(
      "received websocket message for nonexistent follower websocket",
      { websocketId }
    );
    return;
  }
  logger().debug("forwarding message to follower websocket", {
    websocketId,
    isBinary: binary,
    dataType: typeof data,
    dataLength: typeof data === "string" ? data.length : data.byteLength
  });
  followerWs.ws.send(toWirePayload());
}
/**
 * Follower-side handler for a leader's `fwc` (websocket close).
 *
 * The first socket registered under `wi` is closed and unregistered, checked
 * in this order: raw websockets, relay adapters, proxied follower sockets.
 * Unknown ids are logged.
 */
async function handleFollowerWebSocketClose(globalState, close) {
  const { wi: websocketId, code, reason } = close;

  const rawWs = globalState.rawWebSockets.get(websocketId);
  if (rawWs) {
    globalState.rawWebSockets.delete(websocketId);
    rawWs.close(code, reason);
    return;
  }

  const relayWs = globalState.relayWebSockets?.get(websocketId);
  if (relayWs) {
    relayWs._handleClose(code, reason);
    globalState.relayWebSockets.delete(websocketId);
    return;
  }

  const followerWs = globalState.followerWebSockets?.get(websocketId);
  if (followerWs) {
    followerWs.ws.close(code, reason);
    globalState.followerWebSockets.delete(websocketId);
    return;
  }

  logger().warn("received websocket close for nonexistent follower websocket", {
    websocketId
  });
}
// src/coordinate/node/message-handlers/websocket-leader.ts
import {
handleRawWebSocketHandler,
handleWebSocketConnect,
PATH_CONNECT_WEBSOCKET,
PATH_RAW_WEBSOCKET_PREFIX,
toUint8Array as toUint8Array2
} from "@rivetkit/core";
/**
 * Leader-side handler for a follower's `lwo` (websocket open) message.
 *
 * Steps:
 * - Look up the leader actor and build the core websocket handler for the
 *   requested path (connect or raw).
 * - Wire the handler to a fake websocket context whose `send`/`close` publish
 *   `fwm`/`fwc` messages back to the follower node.
 * - Register the socket in `globalState.leaderWebSockets`, confirm with an
 *   `fwo` message, then run the handler's `onOpen`.
 *
 * On any failure the follower is told to close with code 1011.
 */
async function handleLeaderWebSocketOpen(globalState, coordinateDriver, runConfig, actorDriver, nodeId, open) {
  if (!nodeId) {
    logger().error("node id not provided for leader websocket open");
    return;
  }
  logger().debug("handling leader websocket open", {
    nodeId,
    websocketId: open.wi,
    actorId: open.ai,
    url: open.url
  });
  // The context's send/close are synchronous WebSocket-style APIs, so the
  // publish cannot be awaited. Log failures rather than leaving the promise
  // floating (an unhandled rejection can terminate the process).
  const publishToFollower = (message) => {
    coordinateDriver.publishToNode(nodeId, message).catch((error) => {
      logger().warn("failed to publish websocket message to follower", {
        websocketId: open.wi,
        nodeId,
        error: `${error}`
      });
    });
  };
  try {
    const actor = await ActorPeer.getLeaderActor(globalState, open.ai);
    if (!actor) {
      logger().warn("received websocket open for nonexistent actor leader", {
        actorId: open.ai
      });
      return;
    }
    const url = new URL(`ws://actor${open.url}`);
    const path = url.pathname;
    const pathWithQuery = url.pathname + url.search;
    let wsHandler;
    if (path === PATH_CONNECT_WEBSOCKET) {
      wsHandler = await handleWebSocketConnect(
        void 0,
        runConfig,
        actorDriver,
        open.ai,
        open.e,
        open.cp,
        open.ad
      );
    } else if (path.startsWith(PATH_RAW_WEBSOCKET_PREFIX)) {
      wsHandler = await handleRawWebSocketHandler(
        void 0,
        pathWithQuery,
        actorDriver,
        open.ai,
        open.ad
      );
    } else {
      throw new Error(`Unreachable path: ${path}`);
    }
    const fakeWsContext = {
      send: (data) => {
        const isBinary = data instanceof ArrayBuffer || ArrayBuffer.isView(data);
        const encodedData = isBinary ? toUint8Array2(data) : data;
        publishToFollower({
          b: {
            fwm: {
              wi: open.wi,
              data: encodedData,
              binary: isBinary
            }
          }
        });
      },
      close: (code, reason) => {
        publishToFollower({
          b: {
            fwc: {
              wi: open.wi,
              code,
              reason
            }
          }
        });
      }
    };
    globalState.leaderWebSockets = globalState.leaderWebSockets || /* @__PURE__ */ new Map();
    globalState.leaderWebSockets.set(open.wi, {
      wsHandler,
      wsContext: fakeWsContext,
      actorId: open.ai
    });
    logger().debug("sending websocket open confirmation to follower", {
      websocketId: open.wi,
      nodeId,
      actorId: open.ai
    });
    const openMessage = {
      b: {
        fwo: {
          wi: open.wi
        }
      }
    };
    await coordinateDriver.publishToNode(nodeId, openMessage);
    logger().debug("websocket open confirmation sent", {
      websocketId: open.wi
    });
    wsHandler.onOpen({}, fakeWsContext);
  } catch (error) {
    logger().warn("failed to open websocket", { error: `${error}` });
    const message = {
      b: {
        fwc: {
          wi: open.wi,
          code: 1011,
          // Internal error
          reason: error instanceof Error ? error.message : "Internal server error"
        }
      }
    };
    await coordinateDriver.publishToNode(nodeId, message);
  }
}
/**
 * Leader-side handler for a follower's `lwm` (websocket message).
 *
 * Finds the leader websocket registered under `wi` and passes the payload to
 * its handler's `onMessage`. Binary payloads that are not already a
 * Uint8Array are assumed to be base64 strings (they are decoded with `atob`).
 * Unknown sockets and missing leader actors are logged and ignored.
 */
async function handleLeaderWebSocketMessage(globalState, message) {
  const wsData = globalState.leaderWebSockets?.get(message.wi);
  if (!wsData) {
    logger().warn("received websocket message for nonexistent websocket", {
      websocketId: message.wi
    });
    return;
  }
  const actor = await ActorPeer.getLeaderActor(globalState, wsData.actorId);
  if (!actor) {
    logger().warn("received websocket message for nonexistent actor leader", {
      actorId: wsData.actorId
    });
    return;
  }
  let data = message.data;
  if (message.binary && !(data instanceof Uint8Array)) {
    data = Uint8Array.from(atob(data), (ch) => ch.charCodeAt(0));
  }
  if (!wsData.wsHandler?.onMessage) return;
  wsData.wsHandler.onMessage({ data }, wsData.wsContext);
}
/**
 * Leader-side handler for a follower's `lwc` (websocket close).
 *
 * Unregisters the leader websocket and, if its handler has `onClose`, invokes
 * it with a clean close event. A missing code defaults to 1005 and a missing
 * reason to "". Unknown ids are logged.
 */
async function handleLeaderWebSocketClose(globalState, close) {
  const websocketId = close.wi;
  const wsData = globalState.leaderWebSockets?.get(websocketId);
  if (!wsData) {
    logger().warn("received websocket close for nonexistent websocket", {
      websocketId
    });
    return;
  }
  globalState.leaderWebSockets.delete(websocketId);
  if (!wsData.wsHandler?.onClose) return;
  const closeEvent = {
    wasClean: true,
    code: close.code ?? 1005,
    reason: close.reason ?? ""
  };
  wsData.wsHandler.onClose(closeEvent, wsData.wsContext);
}
// src/coordinate/node/proxy-websocket.ts
import { noopNext } from "@rivetkit/core";
import invariant from "invariant";
async function proxyWebSocket(node, c, path, actorId, encoding, connParams, authData) {
var _a, _b;
const upgradeWebSocket = (_b = (_a = node.runConfig).getUpgradeWebSocket) == null ? void 0 : _b.call(_a);
invariant(upgradeWebSocket, "missing getUpgradeWebSocket");
let clientWs;
const relayConn = new RelayConn(
node.registryConfig,
node.runConfig,
node.driverConfig,
node.actorDriver,
node.inlineClient,
node.coordinateDriver,
node.globalState,
{
disconnect: async (reason) => {
clientWs == null ? void 0 : clientWs.close(1e3, reason);
}
},
actorId
);
await relayConn.start();
const websocketId = crypto.randomUUID();
return upgradeWebSocket(() => ({
onOpen: (event, ws) => {
clientWs = ws;
logger().debug("proxy websocket onOpen called", {
websocketId,
actorId,
path
});
node.globalState.followerWebSockets = node.globalState.followerWebSockets || /* @__PURE__ */ new Map();
node.globalState.followerWebSockets.set(websocketId, {
ws,
relayConn
});
const openMessage = {
b: {
lwo: {
ai: actorId,
wi: websocketId,
url: path,
e: encoding,
cp: connParams,
ad: authData
}
}
};
logger().debug("sending websocket open message to leader", {
websocketId,
actorId
});
const _promise = relayConn.publishMessageToleader(openMessage, true);
},
onMessage: (event, ws) => {
const wsData = node.globalState.followerWebSockets.get(websocketId);
if (!wsData) return;
if (event.data instanceof ArrayBuffer) {
const data = new Uint8Array(event.data);
try {
const message = {
b: {
lwm: {
wi: websocketId,
data,
binary: true
}
}
};
const _promise = relayConn.publishMessageToleader(message, false);
} catch (error) {
if (error instanceof LeaderChangedError) {
ws.close(1001, "Actor leader changed");
node.globalState.followerWebSockets.delete(websocketId);
}
}
} else if (event.data instanceof Blob) {
event.data.arrayBuffer().then((arrayBuffer) => {
const data = new Uint8Array(arrayBuffer);
try {
const message = {
b: {
lwm: {
wi: websocketId,
data,
binary: true
}
}
};
const _promise = relayConn.publishMessageToleader(message, false);
} catch (error) {
if (error instanceof LeaderChangedError) {
ws.close(1001, "Actor leader changed");
node.globalState.followerWebSockets.delete(websocketId);
}
}
}).catch((error) => {
logger().error("failed to convert blob to arraybuffer", { error });
});
} else {
try {
const message = {
b: {
lwm: {
wi: websocketId,
data: event.data,
binary: false
}
}
};
const _promise = relayConn.publishMessageToleader(message, false);
} catch (error) {
if (error instanceof LeaderChangedError) {
ws.close(1001, "Actor leader changed");
node.globalState.followerWebSockets.delete(websocketId);
}
}
}
},
onClose: (event, ws) => {
var _a2;
const wsData = (_a2 = node.globalState.followerWebSockets) == null ? void 0 : _a2.get(websocketId);
if (!wsData) return;
const _promise = relayConn.disconnect(false, "Client closed WebSocket", {
b: {
lwc: {
wi: websocketId,
code: event.code,
reason: event.reason
}
}
});
node.globalState.followerWebSockets.delete(websocketId);
}
}))(c, noopNext());
}
// src/coordinate/node/relay-websocket-adapter.ts
import { toUint8Array as toUint8Array3 } from "@rivetkit/core";
/**
 * WebSocket-like object whose traffic is relayed to the actor's leader node
 * through a `RelayConn` instead of a real socket.
 *
 * Outgoing data becomes `lwm` messages and `close()` sends `lwc`. The node's
 * message handlers call `_handleOpen`, `_handleMessage` and `_handleClose` for
 * frames coming back. The instance registers itself in
 * `node.globalState.relayWebSockets` under its websocket id and removes itself
 * once closed.
 */
var RelayWebSocketAdapter = class {
  #node;
  #websocketId;
  #relayConn;
  // Uses this class's own state constants rather than the global `WebSocket`,
  // which not every runtime exposes.
  #readyState = RelayWebSocketAdapter.CONNECTING;
  #eventListeners = /* @__PURE__ */ new Map();
  #onopen = null;
  #onclose = null;
  #onerror = null;
  #onmessage = null;
  #bufferedAmount = 0;
  #binaryType = "blob";
  #extensions = "";
  #protocol = "";
  #url = "";
  #openPromise;
  #openResolve;
  // Event buffering is needed since events can be fired
  // before JavaScript has a chance to add event listeners (e.g. within the same tick).
  // Buffered events are delivered to the first listener or on<type> handler attached.
  #bufferedEvents = [];

  constructor(node, websocketId, relayConn) {
    this.#node = node;
    this.#websocketId = websocketId;
    this.#relayConn = relayConn;
    this.#openPromise = new Promise((resolve) => {
      this.#openResolve = resolve;
    });
    this.#node.globalState.relayWebSockets = this.#node.globalState.relayWebSockets || /* @__PURE__ */ new Map();
    this.#node.globalState.relayWebSockets.set(websocketId, this);
    logger().debug("relay websocket adapter registered", {
      websocketId,
      nodeId: this.#node.globalState.nodeId,
      relayWebSocketsSize: this.#node.globalState.relayWebSockets.size
    });
  }

  /** Resolves once the leader confirms the socket is open (`_handleOpen`). */
  get openPromise() {
    return this.#openPromise;
  }
  get readyState() {
    return this.#readyState;
  }
  get bufferedAmount() {
    return this.#bufferedAmount;
  }
  get binaryType() {
    return this.#binaryType;
  }
  set binaryType(value) {
    this.#binaryType = value;
  }
  get extensions() {
    return this.#extensions;
  }
  get protocol() {
    return this.#protocol;
  }
  get url() {
    return this.#url;
  }
  get actorId() {
    return this.#relayConn.actorId;
  }

  get onopen() {
    return this.#onopen;
  }
  set onopen(value) {
    this.#onopen = value;
    if (value) {
      this.#flushBufferedEvents("open");
    }
  }
  get onclose() {
    return this.#onclose;
  }
  set onclose(value) {
    this.#onclose = value;
    if (value) {
      this.#flushBufferedEvents("close");
    }
  }
  get onerror() {
    return this.#onerror;
  }
  set onerror(value) {
    this.#onerror = value;
    if (value) {
      this.#flushBufferedEvents("error");
    }
  }
  get onmessage() {
    return this.#onmessage;
  }
  set onmessage(value) {
    this.#onmessage = value;
    if (value) {
      this.#flushBufferedEvents("message");
    }
  }

  /**
   * Sends a string or binary (ArrayBuffer / ArrayBufferView) frame to the
   * leader. Delivery failures surface asynchronously: leader loss closes this
   * socket with 1001; anything else fires an "error" event.
   *
   * @throws DOMException when the socket is not open.
   * @throws Error for Blob or other unsupported payloads.
   */
  send(data) {
    if (this.#readyState !== RelayWebSocketAdapter.OPEN) {
      throw new DOMException("WebSocket is not open");
    }
    let isBinary = false;
    let messageData;
    if (typeof data === "string") {
      messageData = data;
    } else if (data instanceof ArrayBuffer || ArrayBuffer.isView(data)) {
      isBinary = true;
      messageData = toUint8Array3(data);
    } else if (data instanceof Blob) {
      throw new Error("Blob sending not implemented in relay adapter");
    } else {
      throw new Error("Invalid data type");
    }
    const message = {
      b: {
        lwm: {
          wi: this.#websocketId,
          data: messageData,
          binary: isBinary
        }
      }
    };
    this.#relayConn.publishMessageToleader(message, false).catch((error) => {
      if (error instanceof LeaderChangedError) {
        this._handleClose(1001, "Actor leader changed");
      } else {
        const event = new Event("error");
        this.#fireEvent("error", event);
      }
    });
  }

  /**
   * Starts a client-initiated close. The relay connection is disconnected
   * (sending `lwc` to the leader), then the socket becomes CLOSED and a close
   * event fires, unless the leader already closed it in the meantime.
   */
  close(code, reason) {
    if (this.#readyState === RelayWebSocketAdapter.CLOSING || this.#readyState === RelayWebSocketAdapter.CLOSED) {
      return;
    }
    this.#readyState = RelayWebSocketAdapter.CLOSING;
    this.#relayConn.disconnect(false, "Client closed WebSocket", {
      b: {
        lwc: {
          wi: this.#websocketId,
          code,
          reason
        }
      }
    }).catch((error) => {
      logger().warn("failed to disconnect relay connection", {
        websocketId: this.#websocketId,
        error: `${error}`
      });
    }).finally(() => {
      this.#finishClose(code, reason);
    });
  }

  /**
   * Registers a function listener and immediately delivers any events of
   * `type` that were buffered. `options` is accepted but ignored.
   */
  addEventListener(type, listener, options) {
    if (typeof listener === "function") {
      let listeners = this.#eventListeners.get(type);
      if (!listeners) {
        listeners = /* @__PURE__ */ new Set();
        this.#eventListeners.set(type, listeners);
      }
      listeners.add(listener);
      logger().debug(`flushing buffered events for ${type}`, {
        websocketId: this.#websocketId,
        bufferedEventsCount: this.#bufferedEvents.filter((e) => e.type === type).length
      });
      this.#flushBufferedEvents(type);
    }
  }

  removeEventListener(type, listener, options) {
    if (typeof listener === "function") {
      const listeners = this.#eventListeners.get(type);
      if (listeners) {
        listeners.delete(listener);
      }
    }
  }

  /** Not implemented: does not dispatch anything and always returns true. */
  dispatchEvent(event) {
    return true;
  }

  /** Returns the `on<type>` handler property for `type`, or null. */
  #handlerFor(type) {
    switch (type) {
      case "open":
        return this.#onopen;
      case "close":
        return this.#onclose;
      case "error":
        return this.#onerror;
      case "message":
        return this.#onmessage;
      default:
        return null;
    }
  }

  /**
   * Delivers `event` to the registered listeners, then to the `on<type>`
   * handler. Errors thrown by either are logged. Returns whether anything
   * received the event.
   */
  #dispatch(type, event) {
    let delivered = false;
    const listeners = this.#eventListeners.get(type);
    if (listeners && listeners.size > 0) {
      delivered = true;
      for (const listener of listeners) {
        try {
          listener.call(this, event);
        } catch (error) {
          logger().error("error in websocket event listener", { error, type });
        }
      }
    }
    const handler = this.#handlerFor(type);
    if (handler) {
      delivered = true;
      try {
        handler.call(this, event);
      } catch (error) {
        logger().error(`error in on${type} handler`, { error });
      }
    }
    return delivered;
  }

  /** Dispatches `event`, buffering it if nothing is listening for `type` yet. */
  #fireEvent(type, event) {
    if (this.#dispatch(type, event)) return;
    logger().debug(`no ${type} listeners registered, buffering event`);
    this.#bufferedEvents.push({ type, event });
  }

  /**
   * Replays buffered events of `type` to both the listeners and the
   * `on<type>` handler, removing them from the buffer.
   */
  #flushBufferedEvents(type) {
    const pending = this.#bufferedEvents.filter(
      (buffered) => buffered.type === type
    );
    if (pending.length === 0) return;
    this.#bufferedEvents = this.#bufferedEvents.filter(
      (buffered) => buffered.type !== type
    );
    for (const { event } of pending) {
      this.#dispatch(type, event);
    }
  }

  /**
   * Moves to CLOSED, unregisters, and fires a single close event. Does nothing
   * if already CLOSED, so a leader close racing a client close fires only once.
   */
  #finishClose(code, reason) {
    if (this.#readyState === RelayWebSocketAdapter.CLOSED) {
      return;
    }
    this.#readyState = RelayWebSocketAdapter.CLOSED;
    this.#node.globalState.relayWebSockets?.delete(this.#websocketId);
    const event = {
      type: "close",
      target: this,
      code: code || 1e3,
      reason: reason || "",
      wasClean: true
    };
    this.#fireEvent("close", event);
  }

  // Internal method to handle incoming messages from leader
  _handleMessage(data, isBinary) {
    if (this.#readyState !== RelayWebSocketAdapter.OPEN) {
      return;
    }
    let messageData;
    if (isBinary) {
      if (data instanceof Uint8Array) {
        messageData = data;
      } else {
        throw new Error("Binary data must be Uint8Array");
      }
    } else {
      messageData = data;
    }
    const event = new MessageEvent("message", {
      data: messageData,
      origin: "",
      lastEventId: ""
    });
    this.#fireEvent("message", event);
  }

  // Internal method to handle open confirmation from leader
  _handleOpen() {
    logger().debug("_handleOpen called", {
      websocketId: this.#websocketId,
      currentReadyState: this.#readyState,
      isConnecting: this.#readyState === RelayWebSocketAdapter.CONNECTING
    });
    if (this.#readyState !== RelayWebSocketAdapter.CONNECTING) {
      return;
    }
    this.#readyState = RelayWebSocketAdapter.OPEN;
    this.#openResolve();
    const event = new Event("open");
    this.#fireEvent("open", event);
  }

  // Internal method to handle close from leader
  _handleClose(code, reason) {
    this.#finishClose(code, reason);
  }

  // Required WebSocket constants
  static CONNECTING = 0;
  static OPEN = 1;
  static CLOSING = 2;
  static CLOSED = 3;
  // Instance constants
  CONNECTING = 0;
  OPEN = 1;
  CLOSING = 2;
  CLOSED = 3;
};
// src/coordinate/node/mod.ts
/**
 * Coordinate-driver node for this process.
 *
 * Subscribes to this node's channel through the coordinate driver, acks and
 * dispatches incoming relay messages to the leader/follower handlers, and
 * forwards HTTP requests and WebSockets originating in this process to an
 * actor through a `RelayConn`.
 */
var Node = class {
  #registryConfig;
  #runConfig;
  #driverConfig;
  #coordinateDriver;
  #globalState;
  #inlineClient;
  #actorDriver;
  #actorRouter;

  get inlineClient() {
    return this.#inlineClient;
  }

  get actorDriver() {
    return this.#actorDriver;
  }

  /**
   * Stores the collaborators; performs no I/O (call `start()` for that).
   *
   * `managerDriver` is accepted for call-site compatibility but is not
   * stored or used by this class.
   */
  constructor(registryConfig, runConfig, driverConfig, managerDriver, coordinateDriver, globalState, inlineClient, actorDriver, actorRouter) {
    this.#registryConfig = registryConfig;
    this.#runConfig = runConfig;
    this.#driverConfig = driverConfig;
    this.#coordinateDriver = coordinateDriver;
    this.#globalState = globalState;
    this.#inlineClient = inlineClient;
    this.#actorDriver = actorDriver;
    this.#actorRouter = actorRouter;
  }

  get globalState() {
    return this.#globalState;
  }

  get coordinateDriver() {
    return this.#coordinateDriver;
  }

  get registryConfig() {
    return this.#registryConfig;
  }

  get runConfig() {
    return this.#runConfig;
  }

  get driverConfig() {
    return this.#driverConfig;
  }

  /**
   * Subscribes this node to its own channel on the coordinate driver.
   *
   * The Redis coordinate driver invokes the message callback without awaiting
   * it, so a rejected handler would otherwise surface as an unhandled promise
   * rejection (which terminates Node.js by default). Handler failures are
   * logged here instead; the callback still returns a promise for drivers
   * that do await it.
   */
  async start() {
    logger().debug("starting", { nodeId: this.#globalState.nodeId });
    await this.#coordinateDriver.createNodeSubscriber(
      this.#globalState.nodeId,
      (data) => this.#onMessage(data).catch((error) => {
        logger().error("failed to handle node message", {
          nodeId: this.#globalState.nodeId,
          error: String(error)
        });
      })
    );
    logger().debug("node started", { nodeId: this.#globalState.nodeId });
  }

  /**
   * Handles one decoded coordinate message.
   *
   * Envelope fields as used here: `n` is the sender node id (logged as
   * `fromNodeId`), `m` the message id, and `b` the body, whose key selects the
   * handler. When both `n` and `m` are present, an ack carrying `m` is
   * published back to `n` (without waiting for it) before dispatching.
   *
   * @throws {Error} if a message that requests an ack is itself an ack.
   */
  async #onMessage(data) {
    const shouldAck = !!(data.n && data.m);
    logger().debug("node received message", { data, shouldAck });
    if (shouldAck) {
      invariant2(data.n && data.m, "unreachable");
      if ("a" in data.b) {
        throw new Error("Ack messages cannot request ack in response");
      }
      const messageRaw = {
        b: {
          a: {
            m: data.m
          }
        }
      };
      // Fire-and-forget so acking does not delay dispatch, but never leave
      // the publish promise without a rejection handler.
      this.#coordinateDriver.publishToNode(data.n, messageRaw).catch((error) => {
        logger().warn("failed to publish ack", {
          messageId: data.m,
          targetNodeId: data.n,
          error: String(error)
        });
      });
    }
    if ("a" in data.b) {
      await this.#onAck(data.b.a);
    } else if ("lf" in data.b) {
      await handleLeaderFetch(
        this.#globalState,
        this.#coordinateDriver,
        this.#actorRouter,
        data.n,
        data.b.lf
      );
    } else if ("ffr" in data.b) {
      handleFollowerFetchResponse(this.#globalState, data.b.ffr);
    } else if ("lwo" in data.b) {
      logger().debug("received lwo (leader websocket open) message", {
        websocketId: data.b.lwo.wi,
        actorId: data.b.lwo.ai,
        fromNodeId: data.n
      });
      await handleLeaderWebSocketOpen(
        this.#globalState,
        this.#coordinateDriver,
        this.#runConfig,
        this.#actorDriver,
        data.n,
        data.b.lwo
      );
    } else if ("lwm" in data.b) {
      await handleLeaderWebSocketMessage(this.#globalState, data.b.lwm);
    } else if ("lwc" in data.b) {
      await handleLeaderWebSocketClose(this.#globalState, data.b.lwc);
    } else if ("fwo" in data.b) {
      logger().debug("received fwo (follower websocket open) message", {
        websocketId: data.b.fwo.wi
      });
      await handleFollowerWebSocketOpen(this.#globalState, data.b.fwo);
    } else if ("fwm" in data.b) {
      await handleFollowerWebSocketMessage(this.#globalState, data.b.fwm);
    } else if ("fwc" in data.b) {
      await handleFollowerWebSocketClose(this.#globalState, data.b.fwc);
    } else {
      assertUnreachable(data.b);
    }
  }

  /**
   * Resolves and removes the resolver stored in
   * `globalState.messageAckResolvers` under the acked message id; logs a
   * warning when no resolver is registered.
   */
  async #onAck({ m: messageId }) {
    const resolveAck = this.#globalState.messageAckResolvers.get(messageId);
    if (resolveAck) {
      resolveAck();
      this.#globalState.messageAckResolvers.delete(messageId);
    } else {
      logger().warn("missing ack resolver", { messageId });
    }
  }

  /**
   * Creates (but does not start) a `RelayConn` for `actorId` whose
   * disconnect hook is a no-op.
   */
  #createRelayConn(actorId) {
    return new RelayConn(
      this.#registryConfig,
      this.#runConfig,
      this.#driverConfig,
      this.#actorDriver,
      this.#inlineClient,
      this.#coordinateDriver,
      this.#globalState,
      {
        disconnect: async (_reason) => {
        }
      },
      actorId
    );
  }

  /**
   * Relays an HTTP request to the actor's leader (`lf` message) and waits for
   * the matching response resolved through `globalState.fetchResponseResolvers`.
   *
   * @param actorId - target actor.
   * @param actorRequest - Fetch API `Request`; its body, if any, is buffered.
   * @param abortController - currently unused.
   * @returns a `Response`: 503 if the message could not be published,
   *   otherwise the relayed status/headers with the error text (when
   *   present) or the body.
   */
  async sendRequest(actorId, actorRequest, abortController) {
    const requestId = crypto.randomUUID();
    const url = new URL(actorRequest.url);
    const headers = {};
    actorRequest.headers.forEach((value, key) => {
      headers[key] = value;
    });
    let body;
    if (actorRequest.body) {
      body = new Uint8Array(await actorRequest.arrayBuffer());
    }
    const relayConn = this.#createRelayConn(actorId);
    await relayConn.start();
    // Register the resolver only after the relay started: the response cannot
    // arrive before the request is published, and a failed start() must not
    // leave a dangling entry in fetchResponseResolvers.
    const responsePromise = new Promise((resolve) => {
      this.#globalState.fetchResponseResolvers.set(requestId, resolve);
    });
    try {
      const message = {
        b: {
          lf: {
            ri: requestId,
            ai: actorId,
            method: actorRequest.method,
            url: url.pathname + url.search,
            headers,
            body,
            // TODO: Auth data
            ad: void 0
          }
        }
      };
      await relayConn.publishMessageToleader(message, true);
    } catch (error) {
      this.#globalState.fetchResponseResolvers.delete(requestId);
      if (error instanceof Error) {
        return new Response(error.message, { status: 503 });
      }
      return new Response(
        "Service unavailable (cannot send message to actor leader)",
        { status: 503 }
      );
    }
    const response = await responsePromise.finally(() => {
      this.#globalState.fetchResponseResolvers.delete(requestId);
    });
    // A truthy error payload replaces the body; status and headers are relayed either way.
    return new Response(response.error ? response.error : response.body, {
      status: response.status,
      headers: response.headers
    });
  }

  // TODO: Clean up disconnecting logic for websocket. There might be missed edge conditions depending on if client or server terminates the websocket
  /**
   * Opens a relayed WebSocket to `actorId` for the inline client.
   *
   * Registers a `RelayWebSocketAdapter` in `globalState.relayWebSockets`,
   * sends an `lwo` message to the leader, and resolves with the adapter once
   * its `openPromise` settles successfully. If publishing the open message
   * fails, the adapter is unregistered and the error is rethrown.
   */
  async openWebSocket(path, actorId, encoding, connParams) {
    const websocketId = crypto.randomUUID();
    logger().debug("opening websocket for inline client", {
      websocketId,
      actorId,
      path,
      encoding,
      nodeId: this.#globalState.nodeId
    });
    const relayConn = this.#createRelayConn(actorId);
    await relayConn.start();
    const adapter = new RelayWebSocketAdapter(this, websocketId, relayConn);
    this.#globalState.relayWebSockets.set(websocketId, adapter);
    const openMessage = {
      b: {
        lwo: {
          ai: actorId,
          wi: websocketId,
          url: path,
          e: encoding,
          cp: connParams,
          ad: void 0
        }
      }
    };
    try {
      await relayConn.publishMessageToleader(openMessage, true);
    } catch (error) {
      // The caller never receives this adapter, so nothing could close it;
      // drop it from the registry instead of leaking it.
      this.#globalState.relayWebSockets.delete(websocketId);
      throw error;
    }
    logger().debug("websocket adapter created, waiting for open", {
      websocketId
    });
    logger().debug("waiting for websocket adapter open promise", {
      websocketId,
      actorId,
      path,
      encoding,
      adapterReadyState: adapter.readyState
    });
    await adapter.openPromise;
    logger().debug("websocket adapter open promise resolved", {
      websocketId,
      actorId,
      adapterReadyState: adapter.readyState
    });
    logger().debug("websocket adapter ready", { websocketId });
    return adapter;
  }

  // TODO: Implement abort controller
  /** Proxies an HTTP request to the actor via `sendRequest` (no abort support). */
  async proxyRequest(c, actorRequest, actorId) {
    return await this.sendRequest(actorId, actorRequest);
  }

  /**
   * Delegates to the module-level `proxyWebSocket` function, passing this
   * node as the first argument.
   */
  async proxyWebSocket(c, path, actorId, encoding, connParams, authData) {
    return proxyWebSocket(
      this,
      c,
      path,
      actorId,
      encoding,
      connParams,
      authData
    );
  }
};
// src/mod.ts
/**
 * Builds the Redis-backed RivetKit driver.
 *
 * Parses `options` through `RedisDriverConfig`. The `actorPeer` timing values
 * are filled in explicitly because the schema requires the `actorPeer` object
 * itself to be present. It then creates one global state object and one
 * coordinate driver, which are shared by the returned `manager` and `actor`
 * factories.
 *
 * @param options - optional driver options (may be undefined).
 * @returns `{ name: "redis", manager, actor }` driver descriptor.
 */
function createRedisDriver(options) {
  const actorPeerOptions = options?.actorPeer;
  const driverConfig = RedisDriverConfig.parse({
    ...options,
    actorPeer: {
      ...actorPeerOptions,
      leaseDuration: actorPeerOptions?.leaseDuration ?? 3e3,
      renewLeaseGrace: actorPeerOptions?.renewLeaseGrace ?? 1500,
      checkLeaseInterval: actorPeerOptions?.checkLeaseInterval ?? 1e3,
      checkLeaseJitter: actorPeerOptions?.checkLeaseJitter ?? 500,
      messageAckTimeout: actorPeerOptions?.messageAckTimeout ?? 1e3
    }
  });
  const globalState = {
    nodeId: crypto.randomUUID(),
    actorPeers: /* @__PURE__ */ new Map(),
    relayConns: /* @__PURE__ */ new Map(),
    messageAckResolvers: /* @__PURE__ */ new Map(),
    actionResponseResolvers: /* @__PURE__ */ new Map(),
    fetchResponseResolvers: /* @__PURE__ */ new Map(),
    rawWebSockets: /* @__PURE__ */ new Map(),
    followerWebSockets: /* @__PURE__ */ new Map(),
    relayWebSockets: /* @__PURE__ */ new Map()
  };
  const coordinate = new RedisCoordinateDriver(
    driverConfig,
    driverConfig.redis
  );
  return {
    name: "redis",
    manager: (registryConfig, runConfig) => {
      const manager = new RedisManagerDriver(
        registryConfig,
        driverConfig,
        driverConfig.redis
      );
      const inlineClient = createClientWithDriver(
        createInlineClientDriver(manager)
      );
      const actorDriver = new RedisActorDriver(
        globalState,
        driverConfig.redis,
        driverConfig
      );
      const actorRouter = createActorRouter(runConfig, actorDriver);
      const node = new Node(
        registryConfig,
        runConfig,
        driverConfig,
        manager,
        coordinate,
        globalState,
        inlineClient,
        actorDriver,
        actorRouter
      );
      manager.node = node;
      // This factory returns synchronously, so the node starts in the
      // background; report a failed subscription instead of leaving an
      // unhandled promise rejection.
      node.start().catch((error) => {
        logger().error("failed to start node", {
          nodeId: globalState.nodeId,
          error: String(error)
        });
      });
      return manager;
    },
    actor: (registryConfig, runConfig, managerDriver, inlineClient) => {
      return new RedisActorDriver(
        globalState,
        driverConfig.redis,
        driverConfig
      );
    }
  };
}
export {
RedisActorDriver,
RedisDriverConfig,
RedisManagerDriver,
createRedisDriver
};
//# sourceMappingURL=mod.js.map