UNPKG

@rivetkit/redis

Version:

_Lightweight Libraries for Backends_

417 lines (377 loc) 10.7 kB
import {
	ActionContext,
	type ActorRouter,
	type AnyClient,
	createActorRouter,
	createClientWithDriver,
	createInlineClientDriver,
	type Encoding,
	generateConnId,
	generateConnToken,
	handleRawWebSocketHandler,
	handleWebSocketConnect,
	InlineWebSocketAdapter2,
	noopNext,
	PATH_CONNECT_WEBSOCKET,
	PATH_RAW_WEBSOCKET_PREFIX,
	type RegistryConfig,
	RivetCloseEvent,
	type RunConfig,
	UniversalWebSocket,
} from "@rivetkit/core";
import type { ActorDriver, ManagerDriver } from "@rivetkit/core/driver-helpers";
import { assertUnreachable } from "@rivetkit/core/utils";
import * as cbor from "cbor-x";
import type { Context as HonoContext } from "hono";
import invariant from "invariant";
import { ActorPeer } from "../actor-peer";
import type { CoordinateDriverConfig } from "../config";
import type { CoordinateDriver } from "../driver";
import { logger } from "../log";
import { RelayConn } from "../relay-conn";
import type { GlobalState } from "../types";
import {
	LeaderChangedError,
	publishMessageToLeader,
	publishMessageToLeaderNoRetry,
} from "./message";
import {
	handleFollowerFetchResponse,
	handleLeaderFetch,
} from "./message-handlers/fetch";
import {
	handleFollowerWebSocketClose,
	handleFollowerWebSocketMessage,
	handleFollowerWebSocketOpen,
} from "./message-handlers/websocket-follower";
import {
	handleLeaderWebSocketClose,
	handleLeaderWebSocketMessage,
	handleLeaderWebSocketOpen,
} from "./message-handlers/websocket-leader";
import {
	type Ack,
	type NodeMessage,
	NodeMessageSchema,
	type ToFollowerFetchResponse,
} from "./protocol";
import { proxyWebSocket as proxyWebSocketImpl } from "./proxy-websocket";
import { RelayWebSocketAdapter } from "./relay-websocket-adapter";

/**
 * One node of the coordinate driver.
 *
 * A node subscribes to a single per-node pub/sub topic (see {@link Node.start})
 * and dispatches every incoming {@link NodeMessage} to the matching
 * leader/follower handler. It also acts as a client: {@link Node.sendRequest}
 * and {@link Node.openWebSocket} relay HTTP requests and WebSockets to an
 * actor's leader through a {@link RelayConn}.
 */
export class Node {
	readonly #registryConfig: RegistryConfig;
	readonly #runConfig: RunConfig;
	readonly #driverConfig: CoordinateDriverConfig;
	readonly #coordinateDriver: CoordinateDriver;
	readonly #globalState: GlobalState;
	readonly #inlineClient: AnyClient;
	readonly #actorDriver: ActorDriver;
	readonly #actorRouter: ActorRouter;

	get inlineClient() {
		return this.#inlineClient;
	}

	get actorDriver() {
		return this.#actorDriver;
	}

	/**
	 * @param managerDriver Accepted for signature compatibility; it is not
	 *   stored or used by this class.
	 */
	constructor(
		registryConfig: RegistryConfig,
		runConfig: RunConfig,
		driverConfig: CoordinateDriverConfig,
		managerDriver: ManagerDriver,
		coordinateDriver: CoordinateDriver,
		globalState: GlobalState,
		inlineClient: AnyClient,
		actorDriver: ActorDriver,
		actorRouter: ActorRouter,
	) {
		this.#registryConfig = registryConfig;
		this.#runConfig = runConfig;
		this.#driverConfig = driverConfig;
		this.#coordinateDriver = coordinateDriver;
		this.#globalState = globalState;
		this.#inlineClient = inlineClient;
		this.#actorDriver = actorDriver;
		this.#actorRouter = actorRouter;
	}

	get globalState(): GlobalState {
		return this.#globalState;
	}

	get coordinateDriver(): CoordinateDriver {
		return this.#coordinateDriver;
	}

	get registryConfig(): RegistryConfig {
		return this.#registryConfig;
	}

	get runConfig(): RunConfig {
		return this.#runConfig;
	}

	get driverConfig(): CoordinateDriverConfig {
		return this.#driverConfig;
	}

	/** Subscribes this node to its message topic so {@link Node.#onMessage} receives traffic. */
	async start() {
		logger().debug("starting", { nodeId: this.#globalState.nodeId });

		// Subscribe to events
		//
		// We intentionally design this so there's only one topic for the subscriber to listen on in order to reduce chattiness to the pubsub server.
		//
		// If we had a dedicated topic for each actor, we'd have to create a SUB for each leader & follower for each actor which is much more expensive than one for each node.
		//
		// Additionally, in most serverless environments, 1 node usually owns 1 actor, so this would double the RTT to create the required subscribers.
		await this.#coordinateDriver.createNodeSubscriber(
			this.#globalState.nodeId,
			this.#onMessage.bind(this),
		);

		logger().debug("node started", { nodeId: this.#globalState.nodeId });
	}

	/**
	 * Handles one message delivered to this node.
	 *
	 * When the message carries both a sender node id (`n`) and a message id
	 * (`m`), an ack is published back to the sender before the body is
	 * dispatched. The body `b` is dispatched on its single key (ack, leader/
	 * follower fetch, leader/follower WebSocket open/message/close).
	 *
	 * @throws Error if an ack message itself requests an ack.
	 */
	async #onMessage(data: NodeMessage) {
		const shouldAck = !!(data.n && data.m);
		logger().debug("node received message", { data, shouldAck });

		// Ack message
		if (shouldAck) {
			invariant(data.n && data.m, "unreachable");
			if ("a" in data.b) {
				throw new Error("Ack messages cannot request ack in response");
			}

			const messageRaw: NodeMessage = {
				b: {
					a: {
						m: data.m,
					},
				},
			};
			// Deliberately not awaited so dispatch is not delayed by the ack
			// publish, but a rejected publish must not become an unhandled
			// rejection (which would take the process down).
			const ackToNodeId = data.n;
			const ackMessageId = data.m;
			void Promise.resolve(
				this.#coordinateDriver.publishToNode(ackToNodeId, messageRaw),
			).catch((error: unknown) => {
				logger().warn("failed to publish ack", {
					toNodeId: ackToNodeId,
					messageId: ackMessageId,
					error,
				});
			});
		}

		// Handle message
		if ("a" in data.b) {
			await this.#onAck(data.b.a);
		} else if ("lf" in data.b) {
			await handleLeaderFetch(
				this.#globalState,
				this.#coordinateDriver,
				this.#actorRouter,
				data.n,
				data.b.lf,
			);
		} else if ("ffr" in data.b) {
			handleFollowerFetchResponse(this.#globalState, data.b.ffr);
		} else if ("lwo" in data.b) {
			logger().debug("received lwo (leader websocket open) message", {
				websocketId: data.b.lwo.wi,
				actorId: data.b.lwo.ai,
				fromNodeId: data.n,
			});
			await handleLeaderWebSocketOpen(
				this.#globalState,
				this.#coordinateDriver,
				this.#runConfig,
				this.#actorDriver,
				data.n,
				data.b.lwo,
			);
		} else if ("lwm" in data.b) {
			await handleLeaderWebSocketMessage(this.#globalState, data.b.lwm);
		} else if ("lwc" in data.b) {
			await handleLeaderWebSocketClose(this.#globalState, data.b.lwc);
		} else if ("fwo" in data.b) {
			logger().debug("received fwo (follower websocket open) message", {
				websocketId: data.b.fwo.wi,
			});
			await handleFollowerWebSocketOpen(this.#globalState, data.b.fwo);
		} else if ("fwm" in data.b) {
			await handleFollowerWebSocketMessage(this.#globalState, data.b.fwm);
		} else if ("fwc" in data.b) {
			await handleFollowerWebSocketClose(this.#globalState, data.b.fwc);
		} else {
			assertUnreachable(data.b);
		}
	}

	/**
	 * Resolves and removes the pending ack waiter for `messageId`; logs a
	 * warning if no waiter is registered.
	 */
	async #onAck({ m: messageId }: Ack) {
		const resolveAck = this.#globalState.messageAckResolvers.get(messageId);
		if (resolveAck) {
			resolveAck();
			this.#globalState.messageAckResolvers.delete(messageId);
		} else {
			logger().warn("missing ack resolver", { messageId });
		}
	}

	/**
	 * Relays an HTTP request to the leader of `actorId` and waits for its
	 * response.
	 *
	 * @param actorId Target actor.
	 * @param actorRequest Request to forward; its body is fully buffered.
	 * @param abortController Currently ignored (see TODO in the body).
	 * @returns The leader's response, or a 503 response if the request could
	 *   not be published to the leader.
	 */
	async sendRequest(
		actorId: string,
		actorRequest: Request,
		abortController?: AbortController,
	): Promise<Response> {
		// Generate request ID
		const requestId = crypto.randomUUID();

		// Extract request details
		const url = new URL(actorRequest.url);
		const headers: Record<string, string> = {};
		actorRequest.headers.forEach((value, key) => {
			headers[key] = value;
		});
		let body: Uint8Array | undefined;
		if (actorRequest.body) {
			const buffer = await actorRequest.arrayBuffer();
			body = new Uint8Array(buffer);
		}

		// Open connection
		const relayConn = new RelayConn(
			this.#registryConfig,
			this.#runConfig,
			this.#driverConfig,
			this.#actorDriver,
			this.#inlineClient,
			this.#coordinateDriver,
			this.#globalState,
			{
				disconnect: async (_reason: any) => {
					// TODO: Abort request client-side
				},
			},
			actorId,
		);
		await relayConn.start();

		// Register the response resolver only after the relay connection is up:
		// if `start()` throws, nothing is left behind in `fetchResponseResolvers`.
		// It is still registered before publishing, so the response cannot
		// arrive before the resolver exists.
		const responsePromise = new Promise<ToFollowerFetchResponse>((resolve) => {
			this.#globalState.fetchResponseResolvers.set(requestId, resolve);
		});

		// Publish request
		try {
			const message: NodeMessage = {
				b: {
					lf: {
						ri: requestId,
						ai: actorId,
						method: actorRequest.method,
						url: url.pathname + url.search,
						headers,
						body,
						// TODO: Auth data
						ad: undefined,
					},
				},
			};
			await relayConn.publishMessageToleader(message, true);
		} catch (error) {
			this.#globalState.fetchResponseResolvers.delete(requestId);
			if (error instanceof Error) {
				return new Response(error.message, { status: 503 });
			}
			return new Response(
				"Service unavailable (cannot send message to actor leader)",
				{ status: 503 },
			);
		}

		// Wait for the response. NOTE: there is no timeout here and
		// `abortController` is not observed; leader retries happen inside
		// `publishMessageToleader`. TODO: add a timeout / abort support.
		const response = await responsePromise.finally(() => {
			this.#globalState.fetchResponseResolvers.delete(requestId);
		});

		// Handle error response
		if (response.error) {
			return new Response(response.error, {
				status: response.status,
				headers: response.headers,
			});
		}

		// Reconstruct response
		return new Response(response.body, {
			status: response.status,
			headers: response.headers,
		});
	}

	/**
	 * Opens a WebSocket to the leader of `actorId` and resolves once the
	 * relayed socket reports open.
	 *
	 * The adapter is registered in `relayWebSockets` before the open message is
	 * published so the open confirmation can find it; it is removed again if
	 * publishing or opening fails.
	 */
	// TODO: Clean up disconnecting logic for websocket. There might be missed edge conditions depending on if client or server terminates the websocket
	async openWebSocket(
		path: string,
		actorId: string,
		encoding: Encoding,
		connParams: unknown,
	): Promise<WebSocket> {
		// Create WebSocket ID
		const websocketId = crypto.randomUUID();
		logger().debug("opening websocket for inline client", {
			websocketId,
			actorId,
			path,
			encoding,
			nodeId: this.#globalState.nodeId,
		});

		// Open connection
		const relayConn = new RelayConn(
			this.#registryConfig,
			this.#runConfig,
			this.#driverConfig,
			this.#actorDriver,
			this.#inlineClient,
			this.#coordinateDriver,
			this.#globalState,
			{
				disconnect: async (_reason: any) => {
					// TODO: Abort request client-side
				},
			},
			actorId,
		);
		await relayConn.start();

		// Create a WebSocket adapter that relays messages BEFORE sending the open message
		// This ensures the adapter is registered when the open confirmation arrives
		const adapter = new RelayWebSocketAdapter(this, websocketId, relayConn);
		this.#globalState.relayWebSockets.set(websocketId, adapter);

		try {
			// Open WebSocket
			const openMessage: NodeMessage = {
				b: {
					lwo: {
						ai: actorId,
						wi: websocketId,
						url: path,
						e: encoding,
						cp: connParams,
						ad: undefined,
					},
				},
			};
			await relayConn.publishMessageToleader(openMessage, true);
			logger().debug("websocket adapter created, waiting for open", {
				websocketId,
			});

			// Wait for the WebSocket to be open before returning
			logger().debug("waiting for websocket adapter open promise", {
				websocketId,
				actorId,
				path,
				encoding,
				adapterReadyState: adapter.readyState,
			});
			await adapter.openPromise;
		} catch (error) {
			// The caller never receives this adapter, so nothing could ever close
			// it; drop the registry entry instead of leaking it.
			if (this.#globalState.relayWebSockets.get(websocketId) === adapter) {
				this.#globalState.relayWebSockets.delete(websocketId);
			}
			throw error;
		}

		logger().debug("websocket adapter open promise resolved", {
			websocketId,
			actorId,
			adapterReadyState: adapter.readyState,
		});

		logger().debug("websocket adapter ready", { websocketId });
		return adapter;
	}

	/**
	 * Forwards `actorRequest` to the actor's leader via {@link Node.sendRequest}.
	 * `c` is currently unused.
	 */
	// TODO: Implement abort controller
	async proxyRequest(
		c: HonoContext,
		actorRequest: Request,
		actorId: string,
	): Promise<Response> {
		return await this.sendRequest(actorId, actorRequest);
	}

	/** Delegates WebSocket proxying to the `proxy-websocket` implementation. */
	async proxyWebSocket(
		c: HonoContext,
		path: string,
		actorId: string,
		encoding: Encoding,
		connParams: unknown,
		authData: unknown,
	): Promise<Response> {
		return proxyWebSocketImpl(
			this,
			c,
			path,
			actorId,
			encoding,
			connParams,
			authData,
		);
	}
}