UNPKG

@rivetkit/core

Version:

737 lines (658 loc) 18.5 kB
import type { Context as HonoContext, HonoRequest } from "hono";
import { type SSEStreamingApi, streamSSE } from "hono/streaming";
import type { WSContext } from "hono/ws";
import { ActionContext } from "@/actor/action";
import type { AnyConn } from "@/actor/connection";
import { generateConnId, generateConnToken } from "@/actor/connection";
import * as errors from "@/actor/errors";
import type { AnyActorInstance } from "@/actor/instance";
import * as protoHttpAction from "@/actor/protocol/http/action";
import { parseMessage } from "@/actor/protocol/message/mod";
import type * as messageToServer from "@/actor/protocol/message/to-server";
import type { InputData } from "@/actor/protocol/serde";
import {
	deserialize,
	type Encoding,
	EncodingSchema,
	serialize,
} from "@/actor/protocol/serde";
import type { UpgradeWebSocketArgs } from "@/common/inline-websocket-adapter2";
import { deconstructError, stringifyError } from "@/common/utils";
import type { UniversalWebSocket } from "@/common/websocket-interface";
import { HonoWebSocketAdapter } from "@/manager/hono-websocket-adapter";
import type { RunConfig } from "@/registry/run-config";
import type { ActorDriver } from "./driver";
import {
	CONN_DRIVER_GENERIC_HTTP,
	CONN_DRIVER_GENERIC_SSE,
	CONN_DRIVER_GENERIC_WEBSOCKET,
	type GenericHttpDriverState,
	type GenericSseDriverState,
	type GenericWebSocketDriverState,
} from "./generic-conn-driver";
import { logger } from "./log";
import { assertUnreachable } from "./utils";

export interface ConnectWebSocketOpts {
	req?: HonoRequest;
	encoding: Encoding;
	actorId: string;
	params: unknown;
	authData: unknown;
}

export interface ConnectWebSocketOutput {
	onOpen: (ws: WSContext) => void;
	onMessage: (message: messageToServer.ToServer) => void;
	onClose: () => void;
}

export interface ConnectSseOpts {
	req?: HonoRequest;
	encoding: Encoding;
	params: unknown;
	actorId: string;
	authData: unknown;
}

export interface ConnectSseOutput {
	onOpen: (stream: SSEStreamingApi) => void;
	onClose: () => Promise<void>;
}

export interface ActionOpts {
	req?: HonoRequest;
	params: unknown;
	actionName: string;
	actionArgs: unknown[];
	actorId: string;
	authData: unknown;
}

export interface ActionOutput {
	output: unknown;
}

export interface ConnsMessageOpts {
	req?: HonoRequest;
	connId: string;
	connToken: string;
	message: messageToServer.ToServer;
	actorId: string;
}

export interface FetchOpts {
	request: Request;
	actorId: string;
	authData: unknown;
}

export interface WebSocketOpts {
	request: Request;
	websocket: UniversalWebSocket;
	actorId: string;
	authData: unknown;
}

/**
 * Creates the event handlers for a protocol WebSocket connection to an actor.
 *
 * The actor is loaded before the handlers are returned. If loading fails, the
 * returned handlers close the socket with an error code. Otherwise the
 * connection is created in the background inside `onOpen`. `onMessage` and
 * `onClose` wait for that setup to finish before touching the connection.
 *
 * @param c - Hono context of the upgrade request, if any. Used to read the
 *   expose-internal-error header and to pass the raw request to `prepareConn`.
 * @param runConfig - Provides `maxIncomingMessageSize` for inbound messages.
 * @param actorDriver - Loads the actor and holds the generic connection state
 *   where the socket is registered.
 * @param actorId - ID of the actor to connect to.
 * @param encoding - Encoding used to parse inbound messages.
 * @param parameters - Connection parameters forwarded to the actor.
 * @param authData - Auth data forwarded to `createConn`.
 * @returns WebSocket upgrade handlers.
 */
export async function handleWebSocketConnect(
	c: HonoContext | undefined,
	runConfig: RunConfig,
	actorDriver: ActorDriver,
	actorId: string,
	encoding: Encoding,
	parameters: unknown,
	authData: unknown,
): Promise<UpgradeWebSocketArgs> {
	const exposeInternalError = c ? getRequestExposeInternalError(c.req) : false;

	// Logs `error` via `deconstructError` and closes the socket with the
	// resulting error code.
	const closeWithError = (ws: WSContext, error: unknown, wsEvent: string) => {
		const { code } = deconstructError(
			error,
			logger(),
			{ wsEvent },
			exposeInternalError,
		);
		ws.close(1011, code);
	};

	// Setup promise for the init handlers since all other behavior depends on this
	const {
		promise: handlersPromise,
		resolve: handlersResolve,
		reject: handlersReject,
	} = Promise.withResolvers<{
		conn: AnyConn;
		actor: AnyActorInstance;
		connId: string;
	}>();
	// `onOpen` can reject this promise before `onMessage`/`onClose` have
	// attached their own handlers. Without a handler attached up front, that
	// surfaces as an unhandled promise rejection. The failure is still
	// reported: `onOpen` closes the socket and the other handlers log it.
	handlersPromise.catch(() => {});

	// Pre-load the actor to catch errors early
	let actor: AnyActorInstance;
	try {
		actor = await actorDriver.loadActor(actorId);
	} catch (error) {
		// Return handler that immediately closes with error
		return {
			onOpen: (_evt: any, ws: WSContext) => {
				closeWithError(ws, error, "open");
			},
			onMessage: (_evt: { data: any }, ws: WSContext) => {
				ws.close(1011, "Actor not loaded");
			},
			onClose: (_event: any, _ws: WSContext) => {},
			onError: (_error: unknown) => {},
		};
	}

	return {
		onOpen: (_evt: any, ws: WSContext) => {
			logger().debug("websocket open");

			// Run async operations in background
			void (async () => {
				// Set once the socket has been added to the global state, so a
				// failure afterwards can undo the registration.
				let registeredConnId: string | undefined;
				try {
					const connId = generateConnId();
					const connToken = generateConnToken();
					const connState = await actor.prepareConn(parameters, c?.req.raw);

					// Save socket
					const connGlobalState =
						actorDriver.getGenericConnGlobalState(actorId);
					connGlobalState.websockets.set(connId, ws);
					registeredConnId = connId;
					logger().debug("registered websocket for conn", {
						actorId,
						totalCount: connGlobalState.websockets.size,
					});

					// Create connection
					const conn = await actor.createConn(
						connId,
						connToken,
						parameters,
						connState,
						CONN_DRIVER_GENERIC_WEBSOCKET,
						{ encoding } satisfies GenericWebSocketDriverState,
						authData,
					);

					// Unblock other handlers
					handlersResolve({ conn, actor, connId });
				} catch (error) {
					handlersReject(error);
					// `onClose` only cleans up after a successful setup, so drop
					// the socket registration here if setup failed after it.
					if (registeredConnId !== undefined) {
						actorDriver
							.getGenericConnGlobalState(actorId)
							.websockets.delete(registeredConnId);
					}
					closeWithError(ws, error, "open");
				}
			})();
		},
		onMessage: (evt: { data: any }, ws: WSContext) => {
			// Handle message asynchronously. Any failure (setup, parsing, or
			// processing) closes the socket.
			void (async () => {
				try {
					const { conn, actor } = await handlersPromise;
					logger().debug("received message");

					const value = evt.data.valueOf() as InputData;
					const message = await parseMessage(value, {
						encoding: encoding,
						maxIncomingMessageSize: runConfig.maxIncomingMessageSize,
					});
					await actor.processMessage(message, conn);
				} catch (error) {
					closeWithError(ws, error, "message");
				}
			})();
		},
		onClose: (
			event: {
				wasClean: boolean;
				code: number;
				reason: string;
			},
			ws: WSContext,
		) => {
			const closeInfo = {
				code: event.code,
				reason: event.reason,
				wasClean: event.wasClean,
			};
			if (event.wasClean) {
				logger().info("websocket closed", closeInfo);
			} else {
				logger().warn("websocket closed", closeInfo);
			}

			// HACK: Close socket in order to fix bug with Cloudflare leaving WS in closing state
			// https://github.com/cloudflare/workerd/issues/2569
			ws.close(1000, "hack_force_close");

			// Handle cleanup asynchronously
			void (async () => {
				try {
					const { conn, actor, connId } = await handlersPromise;
					const connGlobalState =
						actorDriver.getGenericConnGlobalState(actorId);
					const didDelete = connGlobalState.websockets.delete(connId);
					if (didDelete) {
						logger().info("removing websocket for conn", {
							totalCount: connGlobalState.websockets.size,
						});
					} else {
						logger().warn("websocket does not exist for conn", {
							actorId,
							totalCount: connGlobalState.websockets.size,
						});
					}

					actor.__removeConn(conn);
				} catch (error) {
					deconstructError(
						error,
						logger(),
						{ wsEvent: "close" },
						exposeInternalError,
					);
				}
			})();
		},
		onError: (_error: unknown) => {
			try {
				// Actors don't need to know about this, since it's abstracted away
				logger().warn("websocket error");
			} catch (error) {
				deconstructError(
					error,
					logger(),
					{ wsEvent: "error" },
					exposeInternalError,
				);
			}
		},
	};
}

/**
 * Creates an SSE connection handler.
 *
 * Opens an SSE stream, creates a generic SSE connection on the actor, and
 * keeps the stream open until the request's abort signal fires. The stream
 * registration and the connection are then removed.
 *
 * @param c - Hono context of the SSE request. Encoding and connection params
 *   are read from its headers.
 * @param runConfig - Accepted for signature parity; unused here.
 * @param actorDriver - Loads the actor and holds the SSE stream registry.
 * @param actorId - ID of the actor to connect to.
 * @param authData - Auth data forwarded to `createConn`.
 */
export async function handleSseConnect(
	c: HonoContext,
	runConfig: RunConfig,
	actorDriver: ActorDriver,
	actorId: string,
	authData: unknown,
) {
	const encoding = getRequestEncoding(c.req);
	const parameters = getRequestConnParams(c.req);

	// Return the main handler with all async work inside
	return streamSSE(c, async (stream) => {
		let actor: AnyActorInstance | undefined;
		let connId: string | undefined;
		let conn: AnyConn | undefined;

		// Removes whatever has been set up so far: the stream registration and
		// the connection.
		const cleanup = () => {
			if (connId !== undefined) {
				actorDriver.getGenericConnGlobalState(actorId).sseStreams.delete(connId);
			}
			if (conn && actor) {
				actor.__removeConn(conn);
			}
		};

		try {
			// Do all async work inside the handler
			actor = await actorDriver.loadActor(actorId);
			connId = generateConnId();
			const connToken = generateConnToken();
			const connState = await actor.prepareConn(parameters, c.req.raw);

			logger().debug("sse open");

			// Save stream
			actorDriver
				.getGenericConnGlobalState(actorId)
				.sseStreams.set(connId, stream);

			// Create connection
			conn = await actor.createConn(
				connId,
				connToken,
				parameters,
				connState,
				CONN_DRIVER_GENERIC_SSE,
				{ encoding } satisfies GenericSseDriverState,
				authData,
			);

			// HACK: This is required so the abort handler below works
			//
			// See https://github.com/honojs/hono/issues/1770#issuecomment-2461966225
			stream.onAbort(() => {});

			// Wait for close
			const abortResolver = Promise.withResolvers<void>();
			const onAbort = () => {
				try {
					logger().debug("sse shutting down");
					cleanup();
				} catch (error) {
					logger().error("error closing sse connection", { error });
				} finally {
					// Always release the waiter, even if cleanup failed.
					abortResolver.resolve();
				}
			};

			const signal = c.req.raw.signal;
			if (signal.aborted) {
				// The client disconnected while the connection was being set up.
				// The `abort` event has already fired and will not fire again,
				// so clean up immediately instead of waiting forever.
				onAbort();
			} else {
				signal.addEventListener("abort", onAbort, { once: true });
			}

			// HACK: Will throw if not configured
			try {
				c.executionCtx.waitUntil(abortResolver.promise);
			} catch {}

			// Wait until connection aborted
			await abortResolver.promise;
		} catch (error) {
			logger().error("error in sse connection", { error });

			// Cleanup on error
			cleanup();

			// Close the stream on error
			stream.close();
		}
	});
}

/**
 * Creates an action handler.
 *
 * Decodes the request body according to the request encoding and validates it
 * against `ActionRequestSchema`. It then runs the action through a short-lived
 * HTTP connection that is always removed afterwards, and encodes `{ o: output }`
 * in the same encoding.
 *
 * @throws errors.InvalidActionRequest when the body cannot be decoded or does
 *   not match the action request schema.
 */
export async function handleAction(
	c: HonoContext,
	runConfig: RunConfig,
	actorDriver: ActorDriver,
	actionName: string,
	actorId: string,
	authData: unknown,
) {
	const encoding = getRequestEncoding(c.req);
	const parameters = getRequestConnParams(c.req);

	logger().debug("handling action", { actionName, encoding });

	// Decode incoming request
	let body: unknown;
	if (encoding === "json") {
		try {
			body = await c.req.json();
		} catch (err) {
			throw new errors.InvalidActionRequest(
				`Invalid JSON: ${stringifyError(err)}`,
			);
		}
	} else if (encoding === "cbor") {
		try {
			const value = await c.req.arrayBuffer();
			const uint8Array = new Uint8Array(value);
			body = await deserialize(uint8Array as unknown as InputData, encoding);
		} catch (err) {
			throw new errors.InvalidActionRequest(
				`Invalid binary format: ${stringifyError(err)}`,
			);
		}
	} else {
		return assertUnreachable(encoding);
	}

	// Validate using the action schema. `safeParse` does not throw, so the
	// failure is reported directly rather than re-wrapping our own error.
	const result = protoHttpAction.ActionRequestSchema.safeParse(body);
	if (!result.success) {
		throw new errors.InvalidActionRequest(
			`Invalid schema: ${stringifyError(result.error)}`,
		);
	}
	const actionArgs: unknown[] = result.data.a;

	// Invoke the action
	let actor: AnyActorInstance | undefined;
	let conn: AnyConn | undefined;
	let output: unknown;
	try {
		actor = await actorDriver.loadActor(actorId);

		// Create conn
		const connState = await actor.prepareConn(parameters, c.req.raw);
		const newConn = await actor.createConn(
			generateConnId(),
			generateConnToken(),
			parameters,
			connState,
			CONN_DRIVER_GENERIC_HTTP,
			{} satisfies GenericHttpDriverState,
			authData,
		);
		conn = newConn;

		// Call action
		const ctx = new ActionContext(actor.actorContext!, newConn);
		output = await actor.executeAction(ctx, actionName, actionArgs);
	} finally {
		// The HTTP connection only lives for this single action call.
		if (conn) {
			actor?.__removeConn(conn);
		}
	}

	// Encode the response in the format expected by ResponseOkSchema
	const responseData = { o: output };
	if (encoding === "json") {
		return c.json(responseData);
	} else if (encoding === "cbor") {
		const serialized = serialize(responseData, encoding);
		return c.body(serialized as Uint8Array, 200, {
			"Content-Type": "application/octet-stream",
		});
	} else {
		return assertUnreachable(encoding);
	}
}

/**
 * Creates a connection message handler.
 *
 * Decodes a single protocol message sent over HTTP for an existing connection,
 * checks the connection token, and hands the message to the actor. JSON and
 * CBOR bodies both go through `parseMessage`, so `maxIncomingMessageSize` and
 * the message schema apply to both encodings.
 *
 * @throws errors.InvalidRequest when the body cannot be parsed.
 * @throws errors.ConnNotFound when `connId` is not a connection of the actor.
 * @throws errors.IncorrectConnToken when `connToken` does not match.
 */
export async function handleConnectionMessage(
	c: HonoContext,
	runConfig: RunConfig,
	actorDriver: ActorDriver,
	connId: string,
	connToken: string,
	actorId: string,
) {
	const encoding = getRequestEncoding(c.req);

	// Validate incoming request
	let message: messageToServer.ToServer;
	if (encoding === "json") {
		try {
			// Parse from text (as with WebSocket text frames) so the JSON path
			// gets the same size limit and schema validation as CBOR.
			const text = await c.req.text();
			message = await parseMessage(text, {
				encoding,
				maxIncomingMessageSize: runConfig.maxIncomingMessageSize,
			});
		} catch (err) {
			throw new errors.InvalidRequest(`Invalid JSON: ${stringifyError(err)}`);
		}
	} else if (encoding === "cbor") {
		try {
			const value = await c.req.arrayBuffer();
			const uint8Array = new Uint8Array(value);
			message = await parseMessage(uint8Array as unknown as InputData, {
				encoding,
				maxIncomingMessageSize: runConfig.maxIncomingMessageSize,
			});
		} catch (err) {
			throw new errors.InvalidRequest(
				`Invalid binary format: ${stringifyError(err)}`,
			);
		}
	} else {
		return assertUnreachable(encoding);
	}

	const actor = await actorDriver.loadActor(actorId);

	// Find connection
	const conn = actor.conns.get(connId);
	if (!conn) {
		throw new errors.ConnNotFound(connId);
	}

	// Authenticate connection
	if (conn._token !== connToken) {
		throw new errors.IncorrectConnToken();
	}

	// Process message
	await actor.processMessage(message, conn);

	return c.json({});
}

/**
 * Creates handlers for a raw (non-protocol) WebSocket to an actor.
 *
 * On open, the Hono socket is wrapped in a `HonoWebSocketAdapter`, which is
 * stored on the socket as `__adapter` so later events can find it. The actor's
 * `handleWebSocket` is then called with a request whose path has the leading
 * `/raw/websocket` prefix removed (the query string is kept).
 */
export async function handleRawWebSocketHandler(
	c: HonoContext | undefined,
	path: string,
	actorDriver: ActorDriver,
	actorId: string,
	authData: unknown,
): Promise<UpgradeWebSocketArgs> {
	const actor = await actorDriver.loadActor(actorId);

	// Return WebSocket event handlers
	return {
		onOpen: (_evt: any, ws: any) => {
			// Wrap the Hono WebSocket in our adapter
			const adapter = new HonoWebSocketAdapter(ws);

			// Store adapter reference on the WebSocket for event handlers
			ws.__adapter = adapter;

			// Extract the path after prefix and preserve query parameters
			const url = new URL(path, "http://actor");
			const pathname = url.pathname.replace(/^\/raw\/websocket/, "") || "/";
			const normalizedPath = pathname + url.search;

			let newRequest: Request;
			if (c) {
				newRequest = new Request(`http://actor${normalizedPath}`, c.req.raw);
			} else {
				newRequest = new Request(`http://actor${normalizedPath}`, {
					method: "GET",
				});
			}

			logger().debug("rewriting websocket url", {
				from: path,
				to: newRequest.url,
			});

			// Call the actor's onWebSocket handler with the adapted WebSocket
			actor.handleWebSocket(adapter, {
				request: newRequest,
				auth: authData,
			});
		},
		onMessage: (event: any, ws: any) => {
			// Find the adapter for this WebSocket
			const adapter = ws.__adapter;
			if (adapter) {
				adapter._handleMessage(event);
			}
		},
		onClose: (evt: any, ws: any) => {
			// Find the adapter for this WebSocket
			const adapter = ws.__adapter;
			if (adapter) {
				adapter._handleClose(evt?.code || 1006, evt?.reason || "");
			}
		},
		onError: (error: any, ws: any) => {
			// Find the adapter for this WebSocket
			const adapter = ws.__adapter;
			if (adapter) {
				adapter._handleError(error);
			}
		},
	};
}

/**
 * Reads and validates the connection encoding from the request headers.
 *
 * @throws errors.InvalidEncoding when the header is missing or not a known
 *   encoding.
 */
export function getRequestEncoding(req: HonoRequest): Encoding {
	const encodingParam = req.header(HEADER_ENCODING);
	if (!encodingParam) {
		throw new errors.InvalidEncoding("undefined");
	}

	const result = EncodingSchema.safeParse(encodingParam);
	if (!result.success) {
		throw new errors.InvalidEncoding(encodingParam as string);
	}

	return result.data;
}

/**
 * Returns true only when the expose-internal-error header is exactly `"true"`.
 */
export function getRequestExposeInternalError(req: HonoRequest): boolean {
	return req.header(HEADER_EXPOSE_INTERNAL_ERROR) === "true";
}

/**
 * Reads the actor query header and parses it as JSON.
 *
 * @throws errors.InvalidRequest when the header is missing.
 * @throws errors.InvalidQueryJSON when the header is not valid JSON.
 */
export function getRequestQuery(c: HonoContext): unknown {
	// Get query parameters for actor lookup
	const queryParam = c.req.header(HEADER_ACTOR_QUERY);
	if (!queryParam) {
		logger().error("missing query parameter");
		throw new errors.InvalidRequest("missing query");
	}

	// Parse the query JSON. No schema validation happens here; the parsed
	// value is returned as `unknown`.
	try {
		const parsed = JSON.parse(queryParam);
		return parsed;
	} catch (error) {
		logger().error("invalid query json", { error });
		throw new errors.InvalidQueryJSON(error);
	}
}

export const HEADER_ACTOR_QUERY = "X-RivetKit-Query";

export const HEADER_ENCODING = "X-RivetKit-Encoding";

// Internal header
export const HEADER_EXPOSE_INTERNAL_ERROR = "X-RivetKit-Expose-Internal-Error";

// IMPORTANT: Params must be in headers or in an E2EE part of the request (i.e. NOT the URL or query string) in order to ensure that tokens can be securely passed in params.
export const HEADER_CONN_PARAMS = "X-RivetKit-Conn-Params";

// Internal header
export const HEADER_AUTH_DATA = "X-RivetKit-Auth-Data";

export const HEADER_ACTOR_ID = "X-RivetKit-Actor";

export const HEADER_CONN_ID = "X-RivetKit-Conn";

export const HEADER_CONN_TOKEN = "X-RivetKit-Conn-Token";

/**
 * Request headers that public clients are allowed to send.
 *
 * Used for CORS.
 */
export const ALLOWED_PUBLIC_HEADERS = [
	"Content-Type",
	"User-Agent",
	HEADER_ACTOR_QUERY,
	HEADER_ENCODING,
	HEADER_CONN_PARAMS,
	HEADER_ACTOR_ID,
	HEADER_CONN_ID,
	HEADER_CONN_TOKEN,
];

/**
 * Reads the connection parameters sent in the conn-params header.
 *
 * @param req - Incoming Hono request.
 * @returns The parsed JSON value, or `null` when the header is missing or empty.
 * @throws errors.InvalidParams when the header value is not valid JSON.
 */
export function getRequestConnParams(req: HonoRequest): unknown {
	const rawParams = req.header(HEADER_CONN_PARAMS);
	if (rawParams === undefined || rawParams === "") {
		return null;
	}

	let parsedParams: unknown;
	try {
		parsedParams = JSON.parse(rawParams);
	} catch (err) {
		throw new errors.InvalidParams(
			`Invalid params JSON: ${stringifyError(err)}`,
		);
	}
	return parsedParams;
}