UNPKG

@rivetkit/core

Version:

810 lines (710 loc) 22.1 kB
import * as cbor from "cbor-x";
import pRetry from "p-retry";
import type { CloseEvent, WebSocket } from "ws";
import type { AnyActorDefinition } from "@/actor/definition";
import type * as wsToClient from "@/actor/protocol/message/to-client";
import type * as wsToServer from "@/actor/protocol/message/to-server";
import type { Encoding } from "@/actor/protocol/serde";
import type {
	UniversalErrorEvent,
	UniversalEventSource,
	UniversalMessageEvent,
} from "@/common/eventsource-interface";
import { assertUnreachable, stringifyError } from "@/common/utils";
import type { ActorQuery } from "@/manager/protocol/query";
import type { ActorDefinitionActions } from "./actor-common";
import {
	ACTOR_CONNS_SYMBOL,
	type ClientDriver,
	type ClientRaw,
	TRANSPORT_SYMBOL,
} from "./client";
import * as errors from "./errors";
import { logger } from "./log";
import { rawHttpFetch, rawWebSocket } from "./raw-utils";
import {
	type WebSocketMessage as ConnMessage,
	messageLength,
	serializeWithEncoding,
} from "./utils";

/** Bookkeeping for an action request that is waiting for its response. */
interface ActionInFlight {
	name: string;
	resolve: (response: wsToClient.ActionResponse) => void;
	reject: (error: Error) => void;
}

/** A single registered event listener. */
interface EventSubscriptions<Args extends Array<unknown>> {
	callback: (...args: Args) => void;
	/** When true the listener is removed after its first invocation. */
	once: boolean;
}

/**
 * A function that unsubscribes from an event.
 */
export type EventUnsubscribe = () => void;

/**
 * A function that handles connection errors.
 */
export type ActorErrorCallback = (error: errors.ActorError) => void;

export interface SendHttpMessageOpts {
	/** Ephemeral messages are dropped instead of queued when they cannot be sent. */
	ephemeral: boolean;
	signal?: AbortSignal;
}

export type ConnTransport =
	| { websocket: WebSocket }
	| { sse: UniversalEventSource };

export const CONNECT_SYMBOL = Symbol("connect");

/**
 * Builds a short JSON preview of `value` for trace logs.
 *
 * Never throws: values JSON cannot represent (e.g. BigInt) must not break the
 * message path just because tracing is enabled.
 */
function previewJson(value: unknown): string {
	try {
		return `${JSON.stringify(value).substring(0, 100)}...`;
	} catch {
		return "<unserializable>";
	}
}

/**
 * Provides underlying functions for {@link ActorConn}. See {@link ActorConn} for using type-safe remote procedure calls.
 *
 * @see {@link ActorConn}
 */
export class ActorConnRaw {
	#disposed = false;

	/* Will be aborted on dispose. */
	#abortController = new AbortController();

	/** If attempting to connect. Helpful for knowing if in a retry loop when reconnecting. */
	#connecting = false;

	// Populated from the init ("i") message; required to send messages over HTTP when using the SSE transport
	#actorId?: string;
	#connectionId?: string;
	#connectionToken?: string;

	#transport?: ConnTransport;

	#messageQueue: wsToServer.ToServer[] = [];
	#actionsInFlight = new Map<number, ActionInFlight>();

	// biome-ignore lint/suspicious/noExplicitAny: Unknown subscription type
	#eventSubscriptions = new Map<string, Set<EventSubscriptions<any[]>>>();

	#errorHandlers = new Set<ActorErrorCallback>();

	#actionIdCounter = 0;

	/**
	 * Interval that keeps the NodeJS process alive if this is the only thing running.
	 *
	 * See https://github.com/nodejs/node/issues/22088
	 */
	#keepNodeAliveInterval: NodeJS.Timeout;

	/** Promise used to indicate the socket has connected successfully. This will be rejected if the connection fails. */
	#onOpenPromise?: PromiseWithResolvers<undefined>;

	#client: ClientRaw;
	#driver: ClientDriver;
	#params: unknown;
	#encodingKind: Encoding;
	#actorQuery: ActorQuery;

	// TODO: ws message queue

	/**
	 * Do not call this directly.
	 *
	 * Creates an instance of ActorConnRaw.
	 *
	 * @protected
	 */
	public constructor(
		private client: ClientRaw,
		private driver: ClientDriver,
		private params: unknown,
		private encodingKind: Encoding,
		private actorQuery: ActorQuery,
	) {
		this.#client = client;
		this.#driver = driver;
		this.#params = params;
		this.#encodingKind = encodingKind;
		this.#actorQuery = actorQuery;

		// The callback is intentionally empty: the pending timer itself is what keeps
		// the event loop alive. The delay must be passed explicitly, otherwise the
		// timer fires back-to-back.
		this.#keepNodeAliveInterval = setInterval(() => {}, 60_000);
	}

	/**
	 * Call a raw action connection. See {@link ActorConn} for type-safe action calls.
	 *
	 * The request is sent immediately if a transport is open, otherwise it is queued
	 * until the connection opens.
	 *
	 * @see {@link ActorConn}
	 * @template Args - The type of arguments to pass to the action function.
	 * @template Response - The type of the response returned by the action function.
	 * @param opts.name - The name of the action function to call.
	 * @param opts.args - The arguments to pass to the action function.
	 * @param opts.signal - Currently not used by this method.
	 * @returns A promise that resolves to the response of the action function.
	 * @throws {errors.ActorConnDisposed} If the connection has been disposed, either
	 *   before the call or while the action is still waiting for its response.
	 */
	async action<
		Args extends Array<unknown> = unknown[],
		Response = unknown,
	>(opts: {
		name: string;
		args: Args;
		signal?: AbortSignal;
	}): Promise<Response> {
		logger().debug("action", { name: opts.name, args: opts.args });

		// Allocate a request ID used to match the response to this call
		const actionId = this.#actionIdCounter;
		this.#actionIdCounter += 1;

		const { promise, resolve, reject } =
			Promise.withResolvers<wsToClient.ActionResponse>();
		this.#actionsInFlight.set(actionId, { name: opts.name, resolve, reject });

		try {
			this.#sendMessage({
				b: {
					ar: {
						i: actionId,
						n: opts.name,
						a: opts.args,
					},
				},
			} satisfies wsToServer.ToServer);
		} catch (err) {
			// The request never left; don't leave an entry behind that nothing will settle
			this.#actionsInFlight.delete(actionId);
			throw err;
		}

		const { i: responseId, o: output } = await promise;
		if (responseId !== actionId)
			throw new Error(
				`Request ID ${actionId} does not match response ID ${responseId}`,
			);

		return output as Response;
	}

	/**
	 * Do not call this directly.
	 *
	 * Establishes a connection to the server using the specified endpoint & encoding & driver.
	 *
	 * @protected
	 */
	public [CONNECT_SYMBOL]() {
		this.#connectWithRetry();
	}

	/**
	 * Runs connection attempts with exponential backoff until one succeeds or the
	 * connection is disposed (which aborts the retry loop).
	 */
	async #connectWithRetry() {
		this.#connecting = true;

		// Attempt to reconnect indefinitely
		try {
			await pRetry(this.#connectAndWait.bind(this), {
				forever: true,
				minTimeout: 250,
				maxTimeout: 30_000,
				onFailedAttempt: (error) => {
					logger().warn("failed to reconnect", {
						attempt: error.attemptNumber,
						error: stringifyError(error),
					});
				},
				// Cancel retry if aborted
				signal: this.#abortController.signal,
			});
		} catch (err) {
			if ((err as Error).name === "AbortError") {
				// Ignore abortions
				logger().info("connection retry aborted");
				return;
			}
			// Unknown error
			throw err;
		} finally {
			// Reset on every exit path so a later close can start a fresh retry loop
			this.#connecting = false;
		}
	}

	/**
	 * A single connection attempt: opens the configured transport, then waits until
	 * the open promise is resolved (init message) or rejected (error/close).
	 */
	async #connectAndWait() {
		try {
			// Create promise for open
			if (this.#onOpenPromise)
				throw new Error("#onOpenPromise already defined");
			this.#onOpenPromise = Promise.withResolvers();

			// Connect transport
			if (this.#client[TRANSPORT_SYMBOL] === "websocket") {
				await this.#connectWebSocket();
			} else if (this.#client[TRANSPORT_SYMBOL] === "sse") {
				await this.#connectSse();
			} else {
				assertUnreachable(this.#client[TRANSPORT_SYMBOL]);
			}

			// Wait for result
			await this.#onOpenPromise.promise;
		} finally {
			// NOTE(review): a transport opened by a failed attempt is not closed here
			// before the next retry; confirm whether that is intended.
			this.#onOpenPromise = undefined;
		}
	}

	/** Opens a WebSocket through the driver and wires its events to the shared handlers. */
	async #connectWebSocket({ signal }: { signal?: AbortSignal } = {}) {
		const ws = await this.#driver.connectWebSocket(
			undefined,
			this.#actorQuery,
			this.#encodingKind,
			this.#params,
			signal ? { signal } : undefined,
		);
		this.#transport = { websocket: ws };
		ws.addEventListener("open", () => {
			logger().debug("websocket open");
		});
		ws.addEventListener("message", (ev) => {
			this.#handleIncoming(ev.data);
		});
		ws.addEventListener("close", (ev) => {
			this.#handleOnClose(ev);
		});
		ws.addEventListener("error", () => {
			this.#handleOnError();
		});
	}

	/** Opens an EventSource through the driver and wires its events to the shared handlers. */
	async #connectSse({ signal }: { signal?: AbortSignal } = {}) {
		const eventSource = await this.#driver.connectSse(
			undefined,
			this.#actorQuery,
			this.#encodingKind,
			this.#params,
			signal ? { signal } : undefined,
		);
		this.#transport = { sse: eventSource };
		eventSource.onopen = () => {
			logger().debug("eventsource open");
			// #handleOnOpen is called on "i" event
		};
		eventSource.onmessage = (ev: UniversalMessageEvent) => {
			this.#handleIncoming(ev.data);
		};
		eventSource.onerror = (ev: UniversalErrorEvent) => {
			if (eventSource.readyState === eventSource.CLOSED) {
				// This error indicates a close event
				this.#handleOnClose(new Event("error"));
			} else {
				// Log error since event source is still open
				this.#handleOnError();
			}
		};
	}

	/**
	 * Called when the init ("i") message arrives: resolves the open promise,
	 * re-sends subscriptions for all active events and flushes queued messages.
	 */
	#handleOnOpen() {
		logger().debug("socket open", {
			messageQueueLength: this.#messageQueue.length,
		});

		// Resolve open promise
		if (this.#onOpenPromise) {
			this.#onOpenPromise.resolve(undefined);
		} else {
			logger().warn("#onOpenPromise is undefined");
		}

		// Resubscribe to all active events
		for (const eventName of this.#eventSubscriptions.keys()) {
			this.#sendSubscription(eventName, true);
		}

		// Flush queue
		//
		// If the message fails to send, the message will be re-queued
		const queue = this.#messageQueue;
		this.#messageQueue = [];
		for (const msg of queue) {
			this.#sendMessage(msg);
		}
	}

	/**
	 * Entry point for transport message events. Runs the async handler and logs any
	 * failure instead of leaving an unhandled promise rejection.
	 */
	#handleIncoming(data: unknown): void {
		this.#handleOnMessage(data).catch((err) => {
			logger().error("failed to handle message", {
				error: stringifyError(err),
			});
		});
	}

	/** Decodes one incoming message and routes it by its body type. */
	async #handleOnMessage(data: unknown) {
		logger().trace("received message", {
			dataType: typeof data,
			isBlob: data instanceof Blob,
			isArrayBuffer: data instanceof ArrayBuffer,
		});

		const response = (await this.#parse(
			data as ConnMessage,
		)) as wsToClient.ToClient;
		logger().trace("parsed message", {
			response: previewJson(response),
		});

		if ("i" in response.b) {
			// Init message: store the IDs/token (needed for HTTP sends over SSE) and mark the connection open
			this.#actorId = response.b.i.ai;
			this.#connectionId = response.b.i.ci;
			this.#connectionToken = response.b.i.ct;
			logger().trace("received init message", {
				actorId: this.#actorId,
				connectionId: this.#connectionId,
			});
			this.#handleOnOpen();
		} else if ("e" in response.b) {
			// Connection error
			const { c: code, m: message, md: metadata, ai: actionId } = response.b.e;

			// Action IDs start at 0, so test for presence rather than truthiness
			if (typeof actionId === "number") {
				const inFlight = this.#takeActionInFlight(actionId);

				logger().warn("action error", {
					actionId: actionId,
					actionName: inFlight.name,
					code,
					message,
					metadata,
				});

				inFlight.reject(new errors.ActorError(code, message, metadata));
			} else {
				logger().warn("connection error", {
					code,
					message,
					metadata,
				});

				// Create a connection error
				const actorError = new errors.ActorError(code, message, metadata);

				// If we have an onOpenPromise, reject it with the error
				if (this.#onOpenPromise) {
					this.#onOpenPromise.reject(actorError);
				}

				// Reject any in-flight requests
				for (const inFlight of this.#actionsInFlight.values()) {
					inFlight.reject(actorError);
				}
				this.#actionsInFlight.clear();

				// Dispatch to error handler if registered
				this.#dispatchActorError(actorError);
			}
		} else if ("ar" in response.b) {
			// Action response OK
			const { i: actionId, o: outputType } = response.b.ar;
			logger().trace("received action response", {
				actionId,
				outputType,
			});

			const inFlight = this.#takeActionInFlight(actionId);
			logger().trace("resolving action promise", {
				actionId,
				actionName: inFlight.name,
			});
			inFlight.resolve(response.b.ar);
		} else if ("ev" in response.b) {
			logger().trace("received event", {
				name: response.b.ev.n,
				argsCount: response.b.ev.a?.length,
			});

			this.#dispatchEvent(response.b.ev);
		} else {
			assertUnreachable(response.b);
		}
	}

	/**
	 * Called by the close event from drivers (for SSE, when an error leaves the
	 * EventSource CLOSED). Rejects a pending open and schedules a reconnect unless
	 * disposed or already connecting.
	 */
	#handleOnClose(event: Event | CloseEvent) {
		// TODO: Handle queue
		// TODO(review): in-flight actions are not rejected here; confirm whether their
		// responses can still arrive after reconnecting

		// Reject open promise
		if (this.#onOpenPromise) {
			this.#onOpenPromise.reject(new Error("Closed"));
		}

		// We can't use `event instanceof CloseEvent` because it's not defined in NodeJS
		//
		// For SSE closes (a plain Event) these properties will be undefined
		const closeEvent = event as CloseEvent;
		if (closeEvent.wasClean) {
			logger().info("socket closed", {
				code: closeEvent.code,
				reason: closeEvent.reason,
				wasClean: closeEvent.wasClean,
			});
		} else {
			logger().warn("socket closed", {
				code: closeEvent.code,
				reason: closeEvent.reason,
				wasClean: closeEvent.wasClean,
			});
		}

		this.#transport = undefined;

		// Automatically reconnect (with backoff via #connectWithRetry). Skip if already attempting to connect.
		if (!this.#disposed && !this.#connecting) {
			// TODO: Fetch actor to check if it's destroyed
			// TODO: Add a way of preserving connection ID for connection state

			// Attempt to connect again
			this.#connectWithRetry();
		}
	}

	/** Called by the onerror event from drivers. */
	#handleOnError() {
		if (this.#disposed) return;

		// More detailed information will be logged in onclose
		logger().warn("socket error");
	}

	/**
	 * Removes and returns the in-flight entry for `id`.
	 *
	 * @throws {errors.InternalError} If no action with that ID is pending.
	 */
	#takeActionInFlight(id: number): ActionInFlight {
		const inFlight = this.#actionsInFlight.get(id);
		if (!inFlight) {
			throw new errors.InternalError(`No in flight response for ${id}`);
		}
		this.#actionsInFlight.delete(id);
		return inFlight;
	}

	/**
	 * Invokes every listener registered for the event. Errors thrown by a listener are
	 * logged and do not prevent the remaining listeners from running.
	 */
	#dispatchEvent(event: wsToClient.Event) {
		const { n: name, a: args = [] } = event;
		const listeners = this.#eventSubscriptions.get(name);
		if (!listeners) return;

		// Iterate over a snapshot so callbacks may (un)subscribe while we dispatch
		for (const listener of [...listeners]) {
			// Skip listeners that an earlier callback in this dispatch unsubscribed
			if (!listeners.has(listener)) continue;

			// Remove one-time listeners before invoking so a throwing callback cannot fire again
			if (listener.once) {
				listeners.delete(listener);
			}

			try {
				listener.callback(...args);
			} catch (err) {
				logger().error("Error in event listener", {
					event: name,
					error: stringifyError(err),
				});
			}
		}

		// Clean up the empty set and tell the server, unless a callback already
		// replaced it with a fresh subscription set
		if (
			listeners.size === 0 &&
			this.#eventSubscriptions.get(name) === listeners
		) {
			this.#eventSubscriptions.delete(name);
			if (!this.#disposed) {
				this.#sendSubscription(name, false);
			}
		}
	}

	/** Calls every registered error handler, logging (not propagating) handler failures. */
	#dispatchActorError(error: errors.ActorError) {
		// Call all registered error handlers
		for (const handler of [...this.#errorHandlers]) {
			try {
				handler(error);
			} catch (err) {
				logger().error("Error in connection error handler", {
					error: stringifyError(err),
				});
			}
		}
	}

	/**
	 * Registers a listener; the first listener for an event sends a subscribe message
	 * and removing the last one sends an unsubscribe message.
	 */
	#addEventSubscription<Args extends Array<unknown>>(
		eventName: string,
		callback: (...args: Args) => void,
		once: boolean,
	): EventUnsubscribe {
		const listener: EventSubscriptions<Args> = {
			callback,
			once,
		};

		let subscriptionSet = this.#eventSubscriptions.get(eventName);
		if (subscriptionSet === undefined) {
			subscriptionSet = new Set();
			this.#eventSubscriptions.set(eventName, subscriptionSet);
			this.#sendSubscription(eventName, true);
		}
		subscriptionSet.add(listener);

		// Return unsubscribe function
		return () => {
			const listeners = this.#eventSubscriptions.get(eventName);
			if (!listeners) return;

			listeners.delete(listener);
			if (listeners.size === 0) {
				this.#eventSubscriptions.delete(eventName);
				// A disposed connection has nothing to tell the server
				if (!this.#disposed) {
					this.#sendSubscription(eventName, false);
				}
			}
		};
	}

	/**
	 * Subscribes to an event that will happen repeatedly.
	 *
	 * @template Args - The type of arguments the event callback will receive.
	 * @param eventName - The name of the event to subscribe to.
	 * @param callback - The callback function to execute when the event is triggered.
	 * @returns A function to unsubscribe from the event.
	 * @see {@link https://rivet.gg/docs/events|Events Documentation}
	 */
	on<Args extends Array<unknown> = unknown[]>(
		eventName: string,
		callback: (...args: Args) => void,
	): EventUnsubscribe {
		return this.#addEventSubscription<Args>(eventName, callback, false);
	}

	/**
	 * Subscribes to an event that will be triggered only once.
	 *
	 * @template Args - The type of arguments the event callback will receive.
	 * @param eventName - The name of the event to subscribe to.
	 * @param callback - The callback function to execute when the event is triggered.
	 * @returns A function to unsubscribe from the event.
	 * @see {@link https://rivet.gg/docs/events|Events Documentation}
	 */
	once<Args extends Array<unknown> = unknown[]>(
		eventName: string,
		callback: (...args: Args) => void,
	): EventUnsubscribe {
		return this.#addEventSubscription<Args>(eventName, callback, true);
	}

	/**
	 * Subscribes to connection errors.
	 *
	 * @param callback - The callback function to execute when a connection error occurs.
	 * @returns A function to unsubscribe from the error handler.
	 */
	onError(callback: ActorErrorCallback): () => void {
		this.#errorHandlers.add(callback);

		// Return unsubscribe function
		return () => {
			this.#errorHandlers.delete(callback);
		};
	}

	/**
	 * Sends a message over the current transport, or queues it (unless ephemeral) when
	 * no open transport is available or the WebSocket send fails.
	 *
	 * @throws {errors.ActorConnDisposed} If the connection has been disposed.
	 */
	#sendMessage(message: wsToServer.ToServer, opts?: SendHttpMessageOpts) {
		if (this.#disposed) {
			throw new errors.ActorConnDisposed();
		}

		let queueMessage = false;
		if (!this.#transport) {
			// No transport connected yet
			queueMessage = true;
		} else if ("websocket" in this.#transport) {
			// readyState 1 === OPEN
			if (this.#transport.websocket.readyState === 1) {
				try {
					const messageSerialized = serializeWithEncoding(
						this.#encodingKind,
						message,
					);
					this.#transport.websocket.send(messageSerialized);
					logger().trace("sent websocket message", {
						len: messageLength(messageSerialized),
					});
				} catch (error) {
					logger().warn("failed to send message, added to queue", {
						error,
					});

					// Assuming the socket is disconnected and will be reconnected soon
					queueMessage = true;
				}
			} else {
				queueMessage = true;
			}
		} else if ("sse" in this.#transport) {
			// readyState 1 === OPEN
			if (this.#transport.sse.readyState === 1) {
				// Spawn in background since #sendMessage cannot be async
				this.#sendHttpMessage(message, opts);
			} else {
				queueMessage = true;
			}
		} else {
			assertUnreachable(this.#transport);
		}

		if (!opts?.ephemeral && queueMessage) {
			this.#messageQueue.push(message);
			logger().debug("queued connection message");
		}
	}

	/**
	 * Sends a message over HTTP (SSE transport). On failure the message is put back at
	 * the front of the queue unless it is ephemeral.
	 */
	async #sendHttpMessage(
		message: wsToServer.ToServer,
		opts?: SendHttpMessageOpts,
	) {
		try {
			if (!this.#actorId || !this.#connectionId || !this.#connectionToken)
				throw new errors.InternalError("Missing connection ID or token.");

			logger().trace("sent http message", {
				message: previewJson(message),
			});

			const res = await this.#driver.sendHttpMessage(
				undefined,
				this.#actorId,
				this.#encodingKind,
				this.#connectionId,
				this.#connectionToken,
				message,
				opts?.signal ? { signal: opts.signal } : undefined,
			);

			if (!res.ok) {
				throw new errors.InternalError(
					`Publish message over HTTP error (${res.statusText}):\n${await res.text()}`,
				);
			}

			// Drain the response body without parsing it: we don't care about its
			// contents, and a parse failure here would wrongly re-queue a delivered message
			await res.arrayBuffer();
		} catch (error) {
			// TODO: This will not automatically trigger a re-broadcast of HTTP events since SSE is separate from the HTTP action

			logger().warn("failed to send message, added to queue", {
				error,
			});

			// Assuming the socket is disconnected and will be reconnected soon
			//
			// Will attempt to resend soon
			if (!opts?.ephemeral) {
				this.#messageQueue.unshift(message);
			}
		}
	}

	/**
	 * Decodes an incoming message according to the connection encoding. For CBOR over
	 * SSE the payload arrives as a base64 string and is decoded to bytes first.
	 */
	async #parse(data: ConnMessage): Promise<unknown> {
		if (this.#encodingKind === "json") {
			if (typeof data !== "string") {
				throw new Error("received non-string for json parse");
			}
			return JSON.parse(data);
		} else if (this.#encodingKind === "cbor") {
			if (!this.#transport) {
				throw new Error("Cannot parse message when no transport defined");
			} else if ("sse" in this.#transport) {
				// Decode base64 since SSE sends raw strings
				if (typeof data === "string") {
					// atob yields one char (0-255) per byte
					const binaryString = atob(data);
					data = Uint8Array.from(binaryString, (char) => char.charCodeAt(0));
				} else {
					throw new errors.InternalError(
						`Expected data to be a string for SSE, got ${data}.`,
					);
				}
			} else if ("websocket" in this.#transport) {
				// Do nothing
			} else {
				assertUnreachable(this.#transport);
			}

			// Decode data
			if (data instanceof Blob) {
				return cbor.decode(new Uint8Array(await data.arrayBuffer()));
			} else if (data instanceof ArrayBuffer) {
				return cbor.decode(new Uint8Array(data));
			} else if (data instanceof Uint8Array) {
				return cbor.decode(data);
			} else {
				throw new Error(
					`received non-binary type for cbor parse: ${typeof data}`,
				);
			}
		} else {
			assertUnreachable(this.#encodingKind);
		}
	}

	/**
	 * Disconnects from the actor.
	 *
	 * Stops reconnection, rejects pending actions with {@link errors.ActorConnDisposed}
	 * and closes the transport.
	 *
	 * @returns {Promise<void>} A promise that resolves when the socket is gracefully closed.
	 */
	async dispose(): Promise<void> {
		// Internally, this "disposes" the connection

		if (this.#disposed) {
			logger().warn("connection already disconnected");
			return;
		}
		this.#disposed = true;

		logger().debug("disposing actor");

		// Clear interval so NodeJS process can exit
		clearInterval(this.#keepNodeAliveInterval);

		// Abort
		this.#abortController.abort();

		// Remove from registry
		this.#client[ACTOR_CONNS_SYMBOL].delete(this);

		// Settle actions still awaiting a response instead of leaving them pending forever
		const disposedError = new errors.ActorConnDisposed();
		for (const inFlight of this.#actionsInFlight.values()) {
			inFlight.reject(disposedError);
		}
		this.#actionsInFlight.clear();

		// Disconnect transport cleanly
		if (!this.#transport) {
			// Nothing to do
		} else if ("websocket" in this.#transport) {
			const { promise, resolve } = Promise.withResolvers();
			this.#transport.websocket.addEventListener("close", () => {
				logger().debug("ws closed");
				resolve(undefined);
			});
			this.#transport.websocket.close();
			await promise;
		} else if ("sse" in this.#transport) {
			this.#transport.sse.close();
		} else {
			assertUnreachable(this.#transport);
		}
		this.#transport = undefined;
	}

	/** Sends an ephemeral (never queued) subscribe/unsubscribe message for an event. */
	#sendSubscription(eventName: string, subscribe: boolean) {
		this.#sendMessage(
			{
				b: {
					sr: {
						e: eventName,
						s: subscribe,
					},
				},
			},
			{ ephemeral: true },
		);
	}
}

/**
 * Connection to an actor. Allows calling actor's remote procedure calls with inferred types. See {@link ActorConnRaw} for underlying methods.
 *
 * @example
 * ```
 * const room = client.connect<ChatRoom>(...etc...);
 * // This calls the action named `sendMessage` on the `ChatRoom` actor.
 * await room.sendMessage('Hello, world!');
 * ```
 *
 * Private methods (e.g. those starting with `_`) are automatically excluded.
 *
 * @template AD The actor class that this connection is for.
 * @see {@link ActorConnRaw}
 */
export type ActorConn<AD extends AnyActorDefinition> = ActorConnRaw &
	ActorDefinitionActions<AD>;