UNPKG

rivetkit

Version:

Lightweight libraries for building stateful actors on edge platforms

607 lines (582 loc) 18.3 kB
import * as cbor from "cbor-x";
import type { Context as HonoContext } from "hono";
import invariant from "invariant";
import type { WebSocket } from "ws";
import type { Encoding } from "@/actor/protocol/serde";
import { assertUnreachable } from "@/actor/utils";
import { ActorError as ClientActorError } from "@/client/errors";
import type { Transport } from "@/client/mod";
import {
  HEADER_ACTOR_QUERY,
  HEADER_CONN_PARAMS,
  HEADER_ENCODING,
  WS_PROTOCOL_ACTOR,
  WS_PROTOCOL_CONN_PARAMS,
  WS_PROTOCOL_ENCODING,
  WS_PROTOCOL_PATH,
  WS_PROTOCOL_TARGET,
  WS_PROTOCOL_TRANSPORT,
} from "@/common/actor-router-consts";
import type { UniversalEventSource } from "@/common/eventsource-interface";
import type { DeconstructedError } from "@/common/utils";
import { importWebSocket } from "@/common/websocket";
import {
  type ActorOutput,
  type CreateInput,
  type GetForIdInput,
  type GetOrCreateWithKeyInput,
  type GetWithKeyInput,
  HEADER_ACTOR_ID,
  type ManagerDisplayInformation,
  type ManagerDriver,
} from "@/driver-helpers/mod";
import type { ActorQuery } from "@/manager/protocol/query";
import type { UniversalWebSocket } from "@/mod";
import type * as protocol from "@/schemas/client-protocol/mod";
import { logger } from "./log";

/**
 * Payload posted to the `/.test/inline-driver/call` endpoint by
 * `makeInlineRequest`.
 */
export interface TestInlineDriverCallRequest {
  /** Encoding the test client was created with. */
  encoding: Encoding;
  /** Transport the test client was created with. */
  transport: Transport;
  /** Name of the driver method to invoke (e.g. `"getForId"`). */
  method: string;
  /** Positional arguments for `method`. */
  args: unknown[];
}

/**
 * Envelope decoded from the `/.test/inline-driver/call` response: either the
 * successful value under `ok` or a deconstructed error under `err`.
 */
export type TestInlineDriverCallResponse<T> =
  | {
      ok: T;
    }
  | {
      err: DeconstructedError;
    };

/**
 * Creates a client driver used for testing the inline client driver.
 *
 * This will send a request to the HTTP server which will then internally call
 * the internal client and return the response.
 *
 * @param endpoint - Base URL of the test HTTP server; `/.test/inline-driver/...`
 *   paths are appended to it.
 * @param encoding - Encoding forwarded with every inline call.
 * @param transport - Transport forwarded with every inline call and advertised
 *   in the WebSocket subprotocols.
 * @returns A `ManagerDriver` whose methods forward to the test endpoints.
 */
export function createTestInlineClientDriver(
  endpoint: string,
  encoding: Encoding,
  transport: Transport,
): ManagerDriver {
  return {
    getForId(input: GetForIdInput): Promise<ActorOutput | undefined> {
      return makeInlineRequest(endpoint, encoding, transport, "getForId", [
        input,
      ]);
    },

    getWithKey(input: GetWithKeyInput): Promise<ActorOutput | undefined> {
      return makeInlineRequest(endpoint, encoding, transport, "getWithKey", [
        input,
      ]);
    },

    getOrCreateWithKey(input: GetOrCreateWithKeyInput): Promise<ActorOutput> {
      return makeInlineRequest(
        endpoint,
        encoding,
        transport,
        "getOrCreateWithKey",
        [input],
      );
    },

    createActor(input: CreateInput): Promise<ActorOutput> {
      return makeInlineRequest(endpoint, encoding, transport, "createActor", [
        input,
      ]);
    },

    /**
     * Forwards a raw HTTP request for `actorId` to the
     * `/.test/inline-driver/send-request/` endpoint.
     *
     * @throws ClientActorError when the endpoint answers with a non-OK JSON
     *   body of the structured form `{ error: { ... } }`.
     */
    async sendRequest(
      actorId: string,
      actorRequest: Request,
    ): Promise<Response> {
      // Normalize path to match other drivers
      const oldUrl = new URL(actorRequest.url);
      const normalizedPath = oldUrl.pathname.startsWith("/")
        ? oldUrl.pathname.slice(1)
        : oldUrl.pathname;
      const pathWithQuery = normalizedPath + oldUrl.search;

      logger().debug({
        msg: "sending raw http request via test inline driver",
        actorId,
        encoding,
        path: pathWithQuery,
      });

      // Use the dedicated raw HTTP endpoint
      const url = `${endpoint}/.test/inline-driver/send-request/${pathWithQuery}`;

      logger().debug({ msg: "rewriting http url", from: oldUrl, to: url });

      // Merge headers with our metadata
      const headers = new Headers(actorRequest.headers);
      headers.set(HEADER_ACTOR_ID, actorId);

      // Forward the request directly. The init object is cast because it
      // carries `duplex`, which is presumably absent from the RequestInit
      // typings in use.
      const response = await fetch(
        new Request(url, {
          method: actorRequest.method,
          headers,
          body: actorRequest.body,
          signal: actorRequest.signal,
          duplex: "half",
        } as RequestInit),
      );

      // Check if it's an error response from our handler
      if (
        !response.ok &&
        response.headers.get("content-type")?.includes("application/json")
      ) {
        try {
          // Clone the response to avoid consuming the body
          const clonedResponse = response.clone();
          // biome-ignore lint/suspicious/noExplicitAny: untyped JSON error payload
          const errorData = (await clonedResponse.json()) as any;
          if (errorData.error) {
            // Handle both error formats:
            // 1. { error: { group, code, message, metadata } } - structured format
            // 2. { error: "message" } - simple string format (from custom onFetch handlers)
            if (typeof errorData.error === "object") {
              // Pass `group` first so the argument order matches the
              // (group, code, message, metadata) constructor call used in
              // makeInlineRequest; omitting it shifted every field by one.
              throw new ClientActorError(
                errorData.error.group,
                errorData.error.code,
                errorData.error.message,
                errorData.error.metadata,
              );
            }
            // For simple string errors, just return the response as-is
            // This allows custom onFetch handlers to return their own error formats
          }
        } catch (e) {
          // Only our own structured error propagates; anything else (e.g. a
          // body that is not valid JSON) means it's not our error format, so
          // return the response as-is.
          if (e instanceof ClientActorError) {
            throw e;
          }
          return response;
        }
      }

      return response;
    },

    /**
     * Opens a WebSocket to `/.test/inline-driver/connect-websocket/ws`,
     * passing the target actor, encoding, transport, path and optional
     * connection params as WebSocket subprotocol entries.
     *
     * `connId` and `connToken` are accepted but not used here.
     */
    async openWebSocket(
      path: string,
      actorId: string,
      encoding: Encoding,
      params: unknown,
      connId?: string,
      connToken?: string,
    ): Promise<UniversalWebSocket> {
      const WebSocket = await importWebSocket();

      // Normalize path to match other drivers
      const normalizedPath = path.startsWith("/") ? path.slice(1) : path;

      // Create WebSocket connection to the test endpoint
      const wsUrl = new URL(
        `${endpoint}/.test/inline-driver/connect-websocket/ws`,
      );

      logger().debug({
        msg: "creating websocket connection via test inline driver",
        url: wsUrl.toString(),
      });

      // Convert http/https to ws/wss
      const wsProtocol = wsUrl.protocol === "https:" ? "wss:" : "ws:";
      const finalWsUrl = `${wsProtocol}//${wsUrl.host}${wsUrl.pathname}`;

      logger().debug({ msg: "connecting to websocket", url: finalWsUrl });

      // Build protocols for the connection
      const protocols: string[] = [];
      protocols.push(`${WS_PROTOCOL_TARGET}actor`);
      protocols.push(`${WS_PROTOCOL_ACTOR}${actorId}`);
      protocols.push(`${WS_PROTOCOL_ENCODING}${encoding}`);
      protocols.push(`${WS_PROTOCOL_TRANSPORT}${transport}`);
      protocols.push(
        `${WS_PROTOCOL_PATH}${encodeURIComponent(normalizedPath)}`,
      );
      if (params !== undefined) {
        protocols.push(
          `${WS_PROTOCOL_CONN_PARAMS}${encodeURIComponent(JSON.stringify(params))}`,
        );
      }

      // Create and return the WebSocket
      // Node & browser WebSocket types are incompatible
      // biome-ignore lint/suspicious/noExplicitAny: see comment above
      const ws = new WebSocket(finalWsUrl, protocols) as any;

      return ws;
    },

    /** Proxies by delegating to `sendRequest`; the Hono context is unused. */
    async proxyRequest(
      c: HonoContext,
      actorRequest: Request,
      actorId: string,
    ): Promise<Response> {
      return await this.sendRequest(actorId, actorRequest);
    },

    /**
     * Not implemented for the test inline driver.
     *
     * @throws Error always (thrown synchronously).
     */
    proxyWebSocket(
      _c: HonoContext,
      _path: string,
      _actorId: string,
      _encoding: Encoding,
      _params: unknown,
    ): Promise<Response> {
      // Throw a real Error (not a bare string) so callers get a stack trace
      // and `instanceof Error` checks work.
      throw new Error("UNIMPLEMENTED");
      // const upgradeWebSocket = this.#runConfig.getUpgradeWebSocket?.();
      // invariant(upgradeWebSocket, "missing getUpgradeWebSocket");
      //
      // const wsHandler = this.openWebSocket(path, actorId, encoding, connParams);
      // return upgradeWebSocket(() => wsHandler)(c, noopNext());
    },

    displayInformation(): ManagerDisplayInformation {
      return { name: "Test Inline", properties: {} };
    },

    // TODO:
    getOrCreateInspectorAccessToken: () => "",

    // NOTE(review): everything below is commented out in the original and is
    // kept verbatim; confirm whether it is still needed before deleting.

    // action: async <Args extends Array<unknown> = unknown[], Response = unknown>(
    //   _c: HonoContext | undefined,
    //   actorQuery: ActorQuery,
    //   encoding: Encoding,
    //   params: unknown,
    //   name: string,
    //   args: Args,
    // ): Promise<Response> => {
    //   return makeInlineRequest<Response>(
    //     endpoint,
    //     encoding,
    //     transport,
    //     "action",
    //     [undefined, actorQuery, encoding, params, name, args],
    //   );
    // },
    //
    // resolveActorId: async (
    //   _c: HonoContext | undefined,
    //   actorQuery: ActorQuery,
    //   encodingKind: Encoding,
    //   params: unknown,
    // ): Promise<string> => {
    //   return makeInlineRequest<string>(
    //     endpoint,
    //     encodingKind,
    //     transport,
    //     "resolveActorId",
    //     [undefined, actorQuery, encodingKind, params],
    //   );
    // },
    //
    // connectWebSocket: async (
    //   _c: HonoContext | undefined,
    //   actorQuery: ActorQuery,
    //   encodingKind: Encoding,
    //   params: unknown,
    // ): Promise<WebSocket> => {
    //   const WebSocket = await importWebSocket();
    //
    //   logger().debug({
    //     msg: "creating websocket connection via test inline driver",
    //     actorQuery,
    //     encodingKind,
    //   });
    //
    //   // Create WebSocket connection to the test endpoint
    //   const wsUrl = new URL(
    //     `${endpoint}/registry/.test/inline-driver/connect-websocket`,
    //   );
    //   wsUrl.searchParams.set("actorQuery", JSON.stringify(actorQuery));
    //   if (params !== undefined)
    //     wsUrl.searchParams.set("params", JSON.stringify(params));
    //   wsUrl.searchParams.set("encodingKind", encodingKind);
    //
    //   // Convert http/https to ws/wss
    //   const wsProtocol = wsUrl.protocol === "https:" ? "wss:" : "ws:";
    //   const finalWsUrl = `${wsProtocol}//${wsUrl.host}${wsUrl.pathname}${wsUrl.search}`;
    //
    //   logger().debug({ msg: "connecting to websocket", url: finalWsUrl });
    //
    //   // Create and return the WebSocket
    //   // Node & browser WebSocket types are incompatible
    //   const ws = new WebSocket(finalWsUrl, [
    //     // HACK: See packages/drivers/cloudflare-workers/src/websocket.ts
    //     "rivetkit",
    //   ]) as any;
    //
    //   return ws;
    // },
    //
    // connectSse: async (
    //   _c: HonoContext | undefined,
    //   actorQuery: ActorQuery,
    //   encodingKind: Encoding,
    //   params: unknown,
    // ): Promise<UniversalEventSource> => {
    //   logger().debug({
    //     msg: "creating sse connection via test inline driver",
    //     actorQuery,
    //     encodingKind,
    //     params,
    //   });
    //
    //   // Dynamically import EventSource if needed
    //   const EventSourceImport = await import("eventsource");
    //   // Handle both ES modules (default) and CommonJS export patterns
    //   const EventSourceConstructor =
    //     (EventSourceImport as any).default || EventSourceImport;
    //
    //   // Encode parameters for the URL
    //   const actorQueryParam = encodeURIComponent(JSON.stringify(actorQuery));
    //   const encodingParam = encodeURIComponent(encodingKind);
    //   const paramsParam = params
    //     ? encodeURIComponent(JSON.stringify(params))
    //     : null;
    //
    //   // Create SSE connection URL
    //   const sseUrl = new URL(
    //     `${endpoint}/registry/.test/inline-driver/connect-sse`,
    //   );
    //   sseUrl.searchParams.set("actorQueryRaw", actorQueryParam);
    //   sseUrl.searchParams.set("encodingKind", encodingParam);
    //   if (paramsParam) {
    //     sseUrl.searchParams.set("params", paramsParam);
    //   }
    //
    //   logger().debug({ msg: "connecting to sse", url: sseUrl.toString() });
    //
    //   // Create and return the EventSource
    //   const eventSource = new EventSourceConstructor(sseUrl.toString());
    //
    //   // Wait for the connection to be established before returning
    //   await new Promise<void>((resolve, reject) => {
    //     eventSource.onopen = () => {
    //       logger().debug({ msg: "sse connection established" });
    //       resolve();
    //     };
    //
    //     eventSource.onerror = (event: Event) => {
    //       logger().error({ msg: "sse connection failed", event });
    //       reject(new Error("Failed to establish SSE connection"));
    //     };
    //
    //     // Set a timeout in case the connection never establishes
    //     setTimeout(() => {
    //       if (eventSource.readyState !== EventSourceConstructor.OPEN) {
    //         reject(new Error("SSE connection timed out"));
    //       }
    //     }, 10000); // 10 second timeout
    //   });
    //
    //   return eventSource as UniversalEventSource;
    // },
    //
    // sendHttpMessage: async (
    //   _c: HonoContext | undefined,
    //   actorId: string,
    //   encoding: Encoding,
    //   connectionId: string,
    //   connectionToken: string,
    //   message: protocol.ToServer,
    // ): Promise<void> => {
    //   logger().debug({
    //     msg: "sending http message via test inline driver",
    //     actorId,
    //     encoding,
    //     connectionId,
    //     transport,
    //   });
    //
    //   const result = await fetch(
    //     `${endpoint}/registry/.test/inline-driver/call`,
    //     {
    //       method: "POST",
    //       headers: {
    //         "Content-Type": "application/json",
    //       },
    //       body: JSON.stringify({
    //         encoding,
    //         transport,
    //         method: "sendHttpMessage",
    //         args: [
    //           undefined,
    //           actorId,
    //           encoding,
    //           connectionId,
    //           connectionToken,
    //           message,
    //         ],
    //       } satisfies TestInlineDriverCallRequest),
    //     },
    //   );
    //
    //   if (!result.ok) {
    //     throw new Error(`Failed to send HTTP message: ${result.statusText}`);
    //   }
    //
    //   // Discard response
    //   await result.body?.cancel();
    // },
    //
    // rawHttpRequest: async (
    //   _c: HonoContext | undefined,
    //   actorQuery: ActorQuery,
    //   encoding: Encoding,
    //   params: unknown,
    //   path: string,
    //   init: RequestInit,
    // ): Promise<Response> => {
    //   // Normalize path to match other drivers
    //   const normalizedPath = path.startsWith("/") ? path.slice(1) : path;
    //
    //   logger().debug({
    //     msg: "sending raw http request via test inline driver",
    //     actorQuery,
    //     encoding,
    //     path: normalizedPath,
    //   });
    //
    //   // Use the dedicated raw HTTP endpoint
    //   const url = `${endpoint}/registry/.test/inline-driver/raw-http/${normalizedPath}`;
    //
    //   logger().debug({ msg: "rewriting http url", from: path, to: url });
    //
    //   // Merge headers with our metadata
    //   const headers = new Headers(init.headers);
    //   headers.set(HEADER_ACTOR_QUERY, JSON.stringify(actorQuery));
    //   headers.set(HEADER_ENCODING, encoding);
    //   if (params !== undefined) {
    //     headers.set(HEADER_CONN_PARAMS, JSON.stringify(params));
    //   }
    //
    //   // Forward the request directly
    //   const response = await fetch(url, {
    //     ...init,
    //     headers,
    //   });
    //
    //   // Check if it's an error response from our handler
    //   if (
    //     !response.ok &&
    //     response.headers.get("content-type")?.includes("application/json")
    //   ) {
    //     try {
    //       // Clone the response to avoid consuming the body
    //       const clonedResponse = response.clone();
    //       const errorData = (await clonedResponse.json()) as any;
    //       if (errorData.error) {
    //         // Handle both error formats:
    //         // 1. { error: { code, message, metadata } } - structured format
    //         // 2. { error: "message" } - simple string format (from custom onFetch handlers)
    //         if (typeof errorData.error === "object") {
    //           throw new ClientActorError(
    //             errorData.error.code,
    //             errorData.error.message,
    //             errorData.error.metadata,
    //           );
    //         }
    //         // For simple string errors, just return the response as-is
    //         // This allows custom onFetch handlers to return their own error formats
    //       }
    //     } catch (e) {
    //       // If it's not our error format, just return the response as-is
    //       if (!(e instanceof ClientActorError)) {
    //         return response;
    //       }
    //       throw e;
    //     }
    //   }
    //
    //   return response;
    // },
    //
    // rawWebSocket: async (
    //   _c: HonoContext | undefined,
    //   actorQuery: ActorQuery,
    //   encoding: Encoding,
    //   params: unknown,
    //   path: string,
    //   protocols: string | string[] | undefined,
    // ): Promise<WebSocket> => {
    //   logger().debug({ msg: "test inline driver rawWebSocket called" });
    //   const WebSocket = await importWebSocket();
    //
    //   // Normalize path to match other drivers
    //   const normalizedPath = path.startsWith("/") ? path.slice(1) : path;
    //
    //   logger().debug({
    //     msg: "creating raw websocket connection via test inline driver",
    //     actorQuery,
    //     encoding,
    //     path: normalizedPath,
    //     protocols,
    //   });
    //
    //   // Create WebSocket connection to the test endpoint
    //   const wsUrl = new URL(
    //     `${endpoint}/registry/.test/inline-driver/raw-websocket`,
    //   );
    //   wsUrl.searchParams.set("actorQuery", JSON.stringify(actorQuery));
    //   if (params !== undefined)
    //     wsUrl.searchParams.set("params", JSON.stringify(params));
    //   wsUrl.searchParams.set("encodingKind", encoding);
    //   wsUrl.searchParams.set("path", normalizedPath);
    //   if (protocols !== undefined)
    //     wsUrl.searchParams.set("protocols", JSON.stringify(protocols));
    //
    //   // Convert http/https to ws/wss
    //   const wsProtocol = wsUrl.protocol === "https:" ? "wss:" : "ws:";
    //   const finalWsUrl = `${wsProtocol}//${wsUrl.host}${wsUrl.pathname}${wsUrl.search}`;
    //
    //   logger().debug({ msg: "connecting to raw websocket", url: finalWsUrl });
    //
    //   logger().debug({
    //     msg: "rewriting websocket url",
    //     from: path,
    //     to: finalWsUrl,
    //   });
    //
    //   // Create and return the WebSocket
    //   // Node & browser WebSocket types are incompatible
    //   const ws = new WebSocket(finalWsUrl, [
    //     // HACK: See packages/drivers/cloudflare-workers/src/websocket.ts
    //     "rivetkit",
    //   ]) as any;
    //
    //   logger().debug({
    //     msg: "test inline driver created websocket",
    //     readyState: ws.readyState,
    //     url: ws.url,
    //   });
    //
    //   return ws;
    // },
  } satisfies ManagerDriver;
}

/**
 * POSTs a CBOR-encoded `TestInlineDriverCallRequest` to
 * `${endpoint}/.test/inline-driver/call` and unwraps the CBOR-encoded
 * `TestInlineDriverCallResponse`.
 *
 * @param endpoint - Base URL of the test HTTP server.
 * @param encoding - Encoding forwarded in the request payload.
 * @param transport - Transport forwarded in the request payload.
 * @param method - Name of the driver method to invoke.
 * @param args - Positional arguments for `method`.
 * @returns The `ok` value of the decoded response.
 * @throws Error when the HTTP response status is not OK.
 * @throws ClientActorError when the decoded response carries `err`.
 */
async function makeInlineRequest<T>(
  endpoint: string,
  encoding: Encoding,
  transport: Transport,
  method: string,
  args: unknown[],
): Promise<T> {
  logger().debug({
    msg: "sending inline request",
    encoding,
    transport,
    method,
    args,
  });

  // Call driver
  // NOTE(review): the body is CBOR-encoded but the Content-Type header says
  // JSON. Left unchanged because the server handler is not visible from here;
  // confirm whether the handler inspects this header before correcting it.
  const response = await fetch(`${endpoint}/.test/inline-driver/call`, {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
    },
    body: cbor.encode({
      encoding,
      transport,
      method,
      args,
    } satisfies TestInlineDriverCallRequest),
    duplex: "half",
  } as RequestInit);

  if (!response.ok) {
    throw new Error(`Failed to call inline ${method}: ${response.statusText}`);
  }

  // Parse response
  const buffer = await response.arrayBuffer();
  const callResponse: TestInlineDriverCallResponse<T> = cbor.decode(
    new Uint8Array(buffer),
  );

  // Throw or OK
  if ("ok" in callResponse) {
    return callResponse.ok;
  } else if ("err" in callResponse) {
    throw new ClientActorError(
      callResponse.err.group,
      callResponse.err.code,
      callResponse.err.message,
      callResponse.err.metadata,
    );
  } else {
    assertUnreachable(callResponse);
  }
}