@rivetkit/core
Version:
1,746 lines (1,574 loc) • 47.5 kB
text/typescript
import { createRoute, OpenAPIHono } from "@hono/zod-openapi";
import * as cbor from "cbor-x";
import {
Hono,
type Context as HonoContext,
type MiddlewareHandler,
} from "hono";
import { cors } from "hono/cors";
import { streamSSE } from "hono/streaming";
import type { WSContext } from "hono/ws";
import invariant from "invariant";
import type { CloseEvent, MessageEvent, WebSocket } from "ws";
import { z } from "zod";
import * as errors from "@/actor/errors";
import type * as protoHttpResolve from "@/actor/protocol/http/resolve";
import type { Transport } from "@/actor/protocol/message/mod";
import type { ToClient } from "@/actor/protocol/message/to-client";
import { type Encoding, serialize } from "@/actor/protocol/serde";
import {
PATH_CONNECT_WEBSOCKET,
PATH_RAW_WEBSOCKET_PREFIX,
} from "@/actor/router";
import {
ALLOWED_PUBLIC_HEADERS,
getRequestEncoding,
getRequestQuery,
HEADER_ACTOR_ID,
HEADER_ACTOR_QUERY,
HEADER_AUTH_DATA,
HEADER_CONN_ID,
HEADER_CONN_PARAMS,
HEADER_CONN_TOKEN,
HEADER_ENCODING,
} from "@/actor/router-endpoints";
import type { ClientDriver } from "@/client/client";
import {
handleRouteError,
handleRouteNotFound,
loggerMiddleware,
} from "@/common/router";
import {
type DeconstructedError,
deconstructError,
noopNext,
stringifyError,
} from "@/common/utils";
import { createManagerInspectorRouter } from "@/inspector/manager";
import { secureInspector } from "@/inspector/utils";
import type { UpgradeWebSocketArgs } from "@/mod";
import type { RegistryConfig } from "@/registry/config";
import type { RunConfig } from "@/registry/run-config";
import { VERSION } from "@/utils";
import { authenticateEndpoint } from "./auth";
import type { ManagerDriver } from "./driver";
import { logger } from "./log";
import type { ActorQuery } from "./protocol/query";
import {
ActorQuerySchema,
ConnectRequestSchema,
ConnectWebSocketRequestSchema,
ConnMessageRequestSchema,
ResolveRequestSchema,
} from "./protocol/query";
/**
* Parse WebSocket protocol headers for query and connection parameters
*/
function parseWebSocketProtocols(protocols: string | undefined): {
queryRaw: string | undefined;
encodingRaw: string | undefined;
connParamsRaw: string | undefined;
} {
let queryRaw: string | undefined;
let encodingRaw: string | undefined;
let connParamsRaw: string | undefined;
if (protocols) {
const protocolList = protocols.split(",").map((p) => p.trim());
for (const protocol of protocolList) {
if (protocol.startsWith("query.")) {
queryRaw = decodeURIComponent(protocol.substring("query.".length));
} else if (protocol.startsWith("encoding.")) {
encodingRaw = protocol.substring("encoding.".length);
} else if (protocol.startsWith("conn_params.")) {
connParamsRaw = decodeURIComponent(
protocol.substring("conn_params.".length),
);
}
}
}
return { queryRaw, encodingRaw, connParamsRaw };
}
const OPENAPI_ENCODING = z.string().openapi({
description: "The encoding format to use for the response (json, cbor)",
example: "json",
});
const OPENAPI_ACTOR_QUERY = z.string().openapi({
description: "Actor query information",
});
const OPENAPI_CONN_PARAMS = z.string().openapi({
description: "Connection parameters",
});
const OPENAPI_ACTOR_ID = z.string().openapi({
description: "Actor ID (used in some endpoints)",
example: "actor-123456",
});
const OPENAPI_CONN_ID = z.string().openapi({
description: "Connection ID",
example: "conn-123456",
});
const OPENAPI_CONN_TOKEN = z.string().openapi({
description: "Connection token",
});
function buildOpenApiResponses<T>(schema: T, validateBody: boolean) {
return {
200: {
description: "Success",
content: validateBody
? {
"application/json": {
schema,
},
}
: {},
},
400: {
description: "User error",
},
500: {
description: "Internal error",
},
};
}
/**
* Only use `validateBody` to `true` if you need to export OpenAPI JSON.
*
* If left enabled for production, this will cause errors. We disable JSON validation since:
* - It prevents us from proxying requests, since validating the body requires consuming the body so we can't forward the body
* - We validate all types at the actor router layer since most requests are proxied
*/
export function createManagerRouter(
registryConfig: RegistryConfig,
runConfig: RunConfig,
inlineClientDriver: ClientDriver,
managerDriver: ManagerDriver,
validateBody: boolean,
): { router: Hono; openapi: OpenAPIHono } {
const router = new OpenAPIHono({ strict: false }).basePath(
runConfig.basePath,
);
router.use("*", loggerMiddleware(logger()));
if (runConfig.cors || runConfig.studio?.cors) {
router.use("*", async (c, next) => {
// Don't apply to WebSocket routes
// HACK: This could be insecure if we had a varargs path. We have to check the path suffix for WS since we don't know the path that this router was mounted.
// HACK: Checking "/websocket/" is not safe, but there is no other way to handle this if we don't know the base path this is
// mounted on
const path = c.req.path;
if (
path.endsWith("/actors/connect/websocket") ||
path.includes("/actors/raw/websocket/") ||
// inspectors implement their own CORS handling
path.endsWith("/inspect") ||
path.endsWith("/actors/inspect")
) {
return next();
}
return cors({
...(runConfig.cors ?? {}),
...(runConfig.studio?.cors ?? {}),
origin: (origin, c) => {
const studioOrigin = runConfig.studio?.cors?.origin;
if (studioOrigin !== undefined) {
if (typeof studioOrigin === "function") {
const allowed = studioOrigin(origin, c);
if (allowed) return allowed;
// Proceed to next CORS config if none provided
} else if (Array.isArray(studioOrigin)) {
return studioOrigin.includes(origin) ? origin : undefined;
} else {
return studioOrigin;
}
}
if (runConfig.cors?.origin !== undefined) {
if (typeof runConfig.cors.origin === "function") {
const allowed = runConfig.cors.origin(origin, c);
if (allowed) return allowed;
} else {
return runConfig.cors.origin as string;
}
}
return null;
},
allowMethods: (origin, c) => {
const studioMethods = runConfig.studio?.cors?.allowMethods;
if (studioMethods) {
if (typeof studioMethods === "function") {
return studioMethods(origin, c);
}
return studioMethods;
}
if (runConfig.cors?.allowMethods) {
if (typeof runConfig.cors.allowMethods === "function") {
return runConfig.cors.allowMethods(origin, c);
}
return runConfig.cors.allowMethods;
}
return [];
},
allowHeaders: [
...(runConfig.cors?.allowHeaders ?? []),
...(runConfig.studio?.cors?.allowHeaders ?? []),
...ALLOWED_PUBLIC_HEADERS,
"Content-Type",
"User-Agent",
],
credentials:
runConfig.cors?.credentials ??
runConfig.studio?.cors?.credentials ??
true,
})(c, next);
});
}
// GET /
router.get("/", (c: HonoContext) => {
return c.text(
"This is an RivetKit registry.\n\nLearn more at https://rivetkit.org",
);
});
// POST /actors/resolve
{
const ResolveQuerySchema = z
.object({
query: z.any().openapi({
example: { getForId: { actorId: "actor-123" } },
}),
})
.openapi("ResolveQuery");
const ResolveResponseSchema = z
.object({
i: z.string().openapi({
example: "actor-123",
}),
})
.openapi("ResolveResponse");
const resolveRoute = createRoute({
method: "post",
path: "/actors/resolve",
request: {
body: {
content: validateBody
? {
"application/json": {
schema: ResolveQuerySchema,
},
}
: {},
},
headers: z.object({
[HEADER_ACTOR_QUERY]: OPENAPI_ACTOR_QUERY,
}),
},
responses: buildOpenApiResponses(ResolveResponseSchema, validateBody),
});
router.openapi(resolveRoute, (c) =>
handleResolveRequest(c, registryConfig, managerDriver),
);
}
// GET /actors/connect/websocket
{
// HACK: WebSockets don't work with mounts, so we need to dynamically match the trailing path
router.use("*", (c, next) => {
if (c.req.path.endsWith("/actors/connect/websocket")) {
return handleWebSocketConnectRequest(
c,
registryConfig,
runConfig,
managerDriver,
);
}
return next();
});
// This route is a noop, just used to generate docs
const wsRoute = createRoute({
method: "get",
path: "/actors/connect/websocket",
responses: {
101: {
description: "WebSocket upgrade",
},
},
});
router.openapi(wsRoute, () => {
throw new Error("Should be unreachable");
});
}
// GET /actors/connect/sse
{
const sseRoute = createRoute({
method: "get",
path: "/actors/connect/sse",
request: {
headers: z.object({
[HEADER_ENCODING]: OPENAPI_ENCODING,
[HEADER_ACTOR_QUERY]: OPENAPI_ACTOR_QUERY,
[HEADER_CONN_PARAMS]: OPENAPI_CONN_PARAMS.optional(),
}),
},
responses: {
200: {
description: "SSE stream",
content: {
"text/event-stream": {
schema: z.unknown(),
},
},
},
},
});
router.openapi(sseRoute, (c) =>
handleSseConnectRequest(c, registryConfig, runConfig, managerDriver),
);
}
// POST /actors/action/:action
{
const ActionParamsSchema = z
.object({
action: z.string().openapi({
param: {
name: "action",
in: "path",
},
example: "myAction",
}),
})
.openapi("ActionParams");
const ActionRequestSchema = z
.object({
query: z.any().openapi({
example: { getForId: { actorId: "actor-123" } },
}),
body: z
.any()
.optional()
.openapi({
example: { param1: "value1", param2: 123 },
}),
})
.openapi("ActionRequest");
const ActionResponseSchema = z.any().openapi("ActionResponse");
const actionRoute = createRoute({
method: "post",
path: "/actors/actions/{action}",
request: {
params: ActionParamsSchema,
body: {
content: validateBody
? {
"application/json": {
schema: ActionRequestSchema,
},
}
: {},
},
headers: z.object({
[HEADER_ENCODING]: OPENAPI_ENCODING,
[HEADER_CONN_PARAMS]: OPENAPI_CONN_PARAMS.optional(),
}),
},
responses: buildOpenApiResponses(ActionResponseSchema, validateBody),
});
router.openapi(actionRoute, (c) =>
handleActionRequest(c, registryConfig, runConfig, managerDriver),
);
}
// POST /actors/message
{
const ConnectionMessageRequestSchema = z
.object({
message: z.any().openapi({
example: { type: "message", content: "Hello, actor!" },
}),
})
.openapi("ConnectionMessageRequest");
const ConnectionMessageResponseSchema = z
.any()
.openapi("ConnectionMessageResponse");
const messageRoute = createRoute({
method: "post",
path: "/actors/message",
request: {
body: {
content: validateBody
? {
"application/json": {
schema: ConnectionMessageRequestSchema,
},
}
: {},
},
headers: z.object({
[HEADER_ACTOR_ID]: OPENAPI_ACTOR_ID,
[HEADER_CONN_ID]: OPENAPI_CONN_ID,
[HEADER_ENCODING]: OPENAPI_ENCODING,
[HEADER_CONN_TOKEN]: OPENAPI_CONN_TOKEN,
}),
},
responses: buildOpenApiResponses(
ConnectionMessageResponseSchema,
validateBody,
),
});
router.openapi(messageRoute, (c) =>
handleMessageRequest(c, registryConfig, runConfig, managerDriver),
);
}
// Raw HTTP endpoints - /actors/raw/http/*
{
const RawHttpRequestBodySchema = z.any().optional().openapi({
description: "Raw request body (can be any content type)",
});
const RawHttpResponseSchema = z.any().openapi({
description: "Raw response from actor's onFetch handler",
});
// Define common route config
const rawHttpRouteConfig = {
path: "/actors/raw/http/*",
request: {
headers: z.object({
[HEADER_ACTOR_QUERY]: OPENAPI_ACTOR_QUERY.optional(),
[HEADER_CONN_PARAMS]: OPENAPI_CONN_PARAMS.optional(),
}),
body: {
content: {
"*/*": {
schema: RawHttpRequestBodySchema,
},
},
},
},
responses: {
200: {
description: "Success - response from actor's onFetch handler",
content: {
"*/*": {
schema: RawHttpResponseSchema,
},
},
},
404: {
description: "Actor does not have an onFetch handler",
},
500: {
description: "Internal server error or invalid response from actor",
},
},
};
// Create routes for each HTTP method
const httpMethods = [
"get",
"post",
"put",
"delete",
"patch",
"head",
"options",
] as const;
for (const method of httpMethods) {
const route = createRoute({
method,
...rawHttpRouteConfig,
});
router.openapi(route, async (c) => {
return handleRawHttpRequest(
c,
registryConfig,
runConfig,
managerDriver,
);
});
}
}
// Raw WebSocket endpoint - /actors/raw/websocket/*
{
// HACK: WebSockets don't work with mounts, so we need to dynamically match the trailing path
router.use("*", async (c, next) => {
if (c.req.path.includes("/raw/websocket/")) {
return handleRawWebSocketRequest(
c,
registryConfig,
runConfig,
managerDriver,
);
}
return next();
});
// This route is a noop, just used to generate docs
const rawWebSocketRoute = createRoute({
method: "get",
path: "/actors/raw/websocket/*",
request: {},
responses: {
101: {
description: "WebSocket upgrade successful",
},
400: {
description: "WebSockets not enabled or invalid request",
},
404: {
description: "Actor does not have an onWebSocket handler",
},
},
});
router.openapi(rawWebSocketRoute, () => {
throw new Error("Should be unreachable");
});
}
if (runConfig.studio?.enabled) {
router.route(
"/actors/inspect",
new Hono()
.use(
cors(runConfig.studio.cors),
secureInspector(runConfig),
universalActorProxy({
registryConfig,
runConfig,
driver: managerDriver,
}),
)
.all("/", (c) =>
// this should be handled by the actor proxy, but just in case
c.text("Unreachable.", 404),
),
);
router.route(
"/inspect",
new Hono()
.use(
cors(runConfig.studio.cors),
secureInspector(runConfig),
async (c, next) => {
const inspector = managerDriver.inspector;
invariant(inspector, "inspector not supported on this platform");
c.set("inspector", inspector);
await next();
},
)
.route("/", createManagerInspectorRouter()),
);
}
if (registryConfig.test.enabled) {
// Add HTTP endpoint to test the inline client
//
// We have to do this in a router since this needs to run in the same server as the RivetKit registry. Some test contexts to not run in the same server.
router.post(".test/inline-driver/call", async (c) => {
// TODO: use openapi instead
const buffer = await c.req.arrayBuffer();
const { encoding, transport, method, args }: TestInlineDriverCallRequest =
cbor.decode(new Uint8Array(buffer));
logger().debug("received inline request", {
encoding,
transport,
method,
args,
});
// Forward inline driver request
let response: TestInlineDriverCallResponse<unknown>;
try {
const output = await ((inlineClientDriver as any)[method] as any)(
...args,
);
response = { ok: output };
} catch (rawErr) {
const err = deconstructError(rawErr, logger(), {}, true);
response = { err };
}
return c.body(cbor.encode(response));
});
router.get(".test/inline-driver/connect-websocket", async (c) => {
const upgradeWebSocket = runConfig.getUpgradeWebSocket?.();
invariant(upgradeWebSocket, "websockets not supported on this platform");
return upgradeWebSocket(async (c: any) => {
const {
actorQuery: actorQueryRaw,
params: paramsRaw,
encodingKind,
} = c.req.query() as {
actorQuery: string;
params?: string;
encodingKind: Encoding;
};
const actorQuery = JSON.parse(actorQueryRaw);
const params =
paramsRaw !== undefined ? JSON.parse(paramsRaw) : undefined;
logger().debug("received test inline driver websocket", {
actorQuery,
params,
encodingKind,
});
// Connect to the actor using the inline client driver - this returns a Promise<WebSocket>
const clientWsPromise = inlineClientDriver.connectWebSocket(
undefined,
actorQuery,
encodingKind,
params,
undefined,
);
return await createTestWebSocketProxy(clientWsPromise, "standard");
})(c, noopNext());
});
router.get(".test/inline-driver/raw-websocket", async (c) => {
const upgradeWebSocket = runConfig.getUpgradeWebSocket?.();
invariant(upgradeWebSocket, "websockets not supported on this platform");
return upgradeWebSocket(async (c: any) => {
const {
actorQuery: actorQueryRaw,
params: paramsRaw,
encodingKind,
path,
protocols: protocolsRaw,
} = c.req.query() as {
actorQuery: string;
params?: string;
encodingKind: Encoding;
path: string;
protocols?: string;
};
const actorQuery = JSON.parse(actorQueryRaw);
const params =
paramsRaw !== undefined ? JSON.parse(paramsRaw) : undefined;
const protocols =
protocolsRaw !== undefined ? JSON.parse(protocolsRaw) : undefined;
logger().debug("received test inline driver raw websocket", {
actorQuery,
params,
encodingKind,
path,
protocols,
});
// Connect to the actor using the inline client driver - this returns a Promise<WebSocket>
logger().debug("calling inlineClientDriver.rawWebSocket");
const clientWsPromise = inlineClientDriver.rawWebSocket(
undefined,
actorQuery,
encodingKind,
params,
path,
protocols,
undefined,
);
logger().debug("calling createTestWebSocketProxy");
return await createTestWebSocketProxy(clientWsPromise, "raw");
})(c, noopNext());
});
// Raw HTTP endpoint for test inline driver
router.all(".test/inline-driver/raw-http/*", async (c) => {
// Extract parameters from headers
const actorQueryHeader = c.req.header(HEADER_ACTOR_QUERY);
const paramsHeader = c.req.header(HEADER_CONN_PARAMS);
const encodingHeader = c.req.header(HEADER_ENCODING);
if (!actorQueryHeader || !encodingHeader) {
return c.text("Missing required headers", 400);
}
const actorQuery = JSON.parse(actorQueryHeader);
const params = paramsHeader ? JSON.parse(paramsHeader) : undefined;
const encoding = encodingHeader as Encoding;
// Extract the path after /raw-http/
const fullPath = c.req.path;
const pathOnly =
fullPath.split("/.test/inline-driver/raw-http/")[1] || "";
// Include query string
const url = new URL(c.req.url);
const pathWithQuery = pathOnly + url.search;
logger().debug("received test inline driver raw http", {
actorQuery,
params,
encoding,
path: pathWithQuery,
method: c.req.method,
});
try {
// Forward the request using the inline client driver
const response = await inlineClientDriver.rawHttpRequest(
undefined,
actorQuery,
encoding,
params,
pathWithQuery,
{
method: c.req.method,
headers: c.req.raw.headers,
body: c.req.raw.body,
},
undefined,
);
// Return the response directly
return response;
} catch (error) {
logger().error("error in test inline raw http", {
error: stringifyError(error),
});
// Return error response
const err = deconstructError(error, logger(), {}, true);
return c.json(
{
error: {
code: err.code,
message: err.message,
metadata: err.metadata,
},
},
err.statusCode,
);
}
});
}
managerDriver.modifyManagerRouter?.(
registryConfig,
router as unknown as Hono,
);
// Mount on both / and /registry
//
// We do this because the default requests are to `/registry/*`.
//
// If using `app.fetch` directly in a non-hono router, paths
// might not be truncated so they'll come to this router as
// `/registry/*`. If mounted correctly in Hono, requests will
// come in at the root as `/*`.
const mountedRouter = new Hono();
mountedRouter.route("/", router);
mountedRouter.route("/registry", router);
// IMPORTANT: These must be on `mountedRouter` instead of `router` or else they will not be called.
mountedRouter.notFound(handleRouteNotFound);
mountedRouter.onError(handleRouteError.bind(undefined, {}));
return { router: mountedRouter, openapi: router };
}
export interface TestInlineDriverCallRequest {
encoding: Encoding;
transport: Transport;
method: string;
args: unknown[];
}
export type TestInlineDriverCallResponse<T> =
| {
ok: T;
}
| {
err: DeconstructedError;
};
/**
* Query the manager driver to get or create a actor based on the provided query
*/
export async function queryActor(
c: HonoContext,
query: ActorQuery,
driver: ManagerDriver,
): Promise<{ actorId: string }> {
logger().debug("querying actor", { query });
let actorOutput: { actorId: string };
if ("getForId" in query) {
const output = await driver.getForId({
c,
actorId: query.getForId.actorId,
});
if (!output) throw new errors.ActorNotFound(query.getForId.actorId);
actorOutput = output;
} else if ("getForKey" in query) {
const existingActor = await driver.getWithKey({
c,
name: query.getForKey.name,
key: query.getForKey.key,
});
if (!existingActor) {
throw new errors.ActorNotFound(
`${query.getForKey.name}:${JSON.stringify(query.getForKey.key)}`,
);
}
actorOutput = existingActor;
} else if ("getOrCreateForKey" in query) {
const getOrCreateOutput = await driver.getOrCreateWithKey({
c,
name: query.getOrCreateForKey.name,
key: query.getOrCreateForKey.key,
input: query.getOrCreateForKey.input,
region: query.getOrCreateForKey.region,
});
actorOutput = {
actorId: getOrCreateOutput.actorId,
};
} else if ("create" in query) {
const createOutput = await driver.createActor({
c,
name: query.create.name,
key: query.create.key,
input: query.create.input,
region: query.create.region,
});
actorOutput = {
actorId: createOutput.actorId,
};
} else {
throw new errors.InvalidRequest("Invalid query format");
}
logger().debug("actor query result", {
actorId: actorOutput.actorId,
});
return { actorId: actorOutput.actorId };
}
/**
* Creates a WebSocket proxy for test endpoints that forwards messages between server and client WebSockets
*/
async function createTestWebSocketProxy(
clientWsPromise: Promise<WebSocket>,
connectionType: string,
): Promise<UpgradeWebSocketArgs> {
// Store a reference to the resolved WebSocket
let clientWs: WebSocket | null = null;
try {
// Resolve the client WebSocket promise
logger().debug("awaiting client websocket promise");
clientWs = await clientWsPromise;
logger().debug("client websocket promise resolved", {
constructor: clientWs?.constructor.name,
});
} catch (error) {
logger().error(
`failed to establish client ${connectionType} websocket connection`,
{ error },
);
return {
onOpen: (_evt, serverWs) => {
serverWs.close(1011, "Failed to establish connection");
},
onMessage: () => {},
onError: () => {},
onClose: () => {},
};
}
// Create WebSocket proxy handlers to relay messages between client and server
return {
onOpen: (_evt: any, serverWs: WSContext) => {
logger().debug(`test ${connectionType} websocket connection opened`);
// Check WebSocket type
logger().debug("clientWs info", {
constructor: clientWs.constructor.name,
hasAddEventListener: typeof clientWs.addEventListener === "function",
readyState: clientWs.readyState,
});
// Add message handler to forward messages from client to server
clientWs.addEventListener("message", (clientEvt: MessageEvent) => {
logger().debug(
`test ${connectionType} websocket connection message from client`,
{
dataType: typeof clientEvt.data,
isBlob: clientEvt.data instanceof Blob,
isArrayBuffer: clientEvt.data instanceof ArrayBuffer,
dataConstructor: clientEvt.data?.constructor?.name,
dataStr:
typeof clientEvt.data === "string"
? clientEvt.data.substring(0, 100)
: undefined,
},
);
if (serverWs.readyState === 1) {
// OPEN
// Handle Blob data
if (clientEvt.data instanceof Blob) {
clientEvt.data
.arrayBuffer()
.then((buffer) => {
logger().debug(
"converted client blob to arraybuffer, sending to server",
{
bufferSize: buffer.byteLength,
},
);
serverWs.send(buffer as any);
})
.catch((error) => {
logger().error("failed to convert blob to arraybuffer", {
error,
});
});
} else {
logger().debug("sending client data directly to server", {
dataType: typeof clientEvt.data,
dataLength:
typeof clientEvt.data === "string"
? clientEvt.data.length
: undefined,
});
serverWs.send(clientEvt.data as any);
}
}
});
// Add close handler to close server when client closes
clientWs.addEventListener("close", (clientEvt: CloseEvent) => {
logger().debug(`test ${connectionType} websocket connection closed`);
if (serverWs.readyState !== 3) {
// Not CLOSED
serverWs.close(clientEvt.code, clientEvt.reason);
}
});
// Add error handler
clientWs.addEventListener("error", () => {
logger().debug(`test ${connectionType} websocket connection error`);
if (serverWs.readyState !== 3) {
// Not CLOSED
serverWs.close(1011, "Error in client websocket");
}
});
},
onMessage: (evt: { data: any }) => {
logger().debug("received message from server", {
dataType: typeof evt.data,
isBlob: evt.data instanceof Blob,
isArrayBuffer: evt.data instanceof ArrayBuffer,
dataConstructor: evt.data?.constructor?.name,
dataStr:
typeof evt.data === "string" ? evt.data.substring(0, 100) : undefined,
});
// Forward messages from server websocket to client websocket
if (clientWs.readyState === 1) {
// OPEN
// Handle Blob data
if (evt.data instanceof Blob) {
evt.data
.arrayBuffer()
.then((buffer) => {
logger().debug("converted blob to arraybuffer, sending", {
bufferSize: buffer.byteLength,
});
clientWs.send(buffer);
})
.catch((error) => {
logger().error("failed to convert blob to arraybuffer", {
error,
});
});
} else {
logger().debug("sending data directly", {
dataType: typeof evt.data,
dataLength:
typeof evt.data === "string" ? evt.data.length : undefined,
});
clientWs.send(evt.data);
}
}
},
onClose: (
event: {
wasClean: boolean;
code: number;
reason: string;
},
serverWs: WSContext,
) => {
logger().debug(`server ${connectionType} websocket closed`, {
wasClean: event.wasClean,
code: event.code,
reason: event.reason,
});
// HACK: Close socket in order to fix bug with Cloudflare leaving WS in closing state
// https://github.com/cloudflare/workerd/issues/2569
serverWs.close(1000, "hack_force_close");
// Close the client websocket when the server websocket closes
if (
clientWs &&
clientWs.readyState !== clientWs.CLOSED &&
clientWs.readyState !== clientWs.CLOSING
) {
// Don't pass code/message since this may affect how close events are triggered
clientWs.close(1000, event.reason);
}
},
onError: (error: unknown) => {
logger().error(`error in server ${connectionType} websocket`, { error });
// Close the client websocket on error
if (
clientWs &&
clientWs.readyState !== clientWs.CLOSED &&
clientWs.readyState !== clientWs.CLOSING
) {
clientWs.close(1011, "Error in server websocket");
}
},
};
}
/**
* Handle SSE connection request
*/
async function handleSseConnectRequest(
c: HonoContext,
registryConfig: RegistryConfig,
runConfig: RunConfig,
driver: ManagerDriver,
): Promise<Response> {
let encoding: Encoding | undefined;
try {
encoding = getRequestEncoding(c.req);
logger().debug("sse connection request received", { encoding });
const params = ConnectRequestSchema.safeParse({
query: getRequestQuery(c),
encoding: c.req.header(HEADER_ENCODING),
connParams: c.req.header(HEADER_CONN_PARAMS),
});
if (!params.success) {
logger().error("invalid connection parameters", {
error: params.error,
});
throw new errors.InvalidRequest(params.error);
}
const query = params.data.query;
// Parse connection parameters for authentication
const connParams = params.data.connParams
? JSON.parse(params.data.connParams)
: undefined;
// Authenticate the request
const authData = await authenticateEndpoint(
c,
driver,
registryConfig,
query,
["connect"],
connParams,
);
// Get the actor ID
const { actorId } = await queryActor(c, query, driver);
invariant(actorId, "Missing actor ID");
logger().debug("sse connection to actor", { actorId });
// Handle based on mode
logger().debug("using custom proxy mode for sse connection");
const url = new URL("http://actor/connect/sse");
// Always build fresh request to prevent forwarding unwanted headers
const proxyRequest = new Request(url);
proxyRequest.headers.set(HEADER_ENCODING, params.data.encoding);
if (params.data.connParams) {
proxyRequest.headers.set(HEADER_CONN_PARAMS, params.data.connParams);
}
if (authData) {
proxyRequest.headers.set(HEADER_AUTH_DATA, JSON.stringify(authData));
}
return await driver.proxyRequest(c, proxyRequest, actorId);
} catch (error) {
// If we receive an error during setup, we send the error and close the socket immediately
//
// We have to return the error over SSE since SSE clients cannot read vanilla HTTP responses
const { code, message, metadata } = deconstructError(error, logger(), {
sseEvent: "setup",
});
return streamSSE(c, async (stream) => {
try {
if (encoding) {
// Serialize and send the connection error
const errorMsg: ToClient = {
b: {
e: {
c: code,
m: message,
md: metadata,
},
},
};
// Send the error message to the client
const serialized = serialize(errorMsg, encoding);
await stream.writeSSE({
data:
typeof serialized === "string"
? serialized
: Buffer.from(serialized).toString("base64"),
});
} else {
// We don't know the encoding, send an error and close
await stream.writeSSE({
data: code,
event: "error",
});
}
} catch (serializeError) {
logger().error("failed to send error to sse client", {
error: serializeError,
});
await stream.writeSSE({
data: "internal error during error handling",
event: "error",
});
}
// Stream will exit completely once function exits
});
}
}
/**
* Handle WebSocket connection request
*/
async function handleWebSocketConnectRequest(
c: HonoContext,
registryConfig: RegistryConfig,
runConfig: RunConfig,
driver: ManagerDriver,
): Promise<Response> {
const upgradeWebSocket = runConfig.getUpgradeWebSocket?.();
if (!upgradeWebSocket) {
return c.text(
"WebSockets are not enabled for this driver. Use SSE instead.",
400,
);
}
let encoding: Encoding | undefined;
try {
logger().debug("websocket connection request received");
// Parse configuration from Sec-WebSocket-Protocol header
//
// We use this instead of query parameters since this is more secure than
// query parameters. Query parameters often get logged.
//
// Browsers don't support using headers, so this is the only way to
// pass data securely.
const protocols = c.req.header("sec-websocket-protocol");
const { queryRaw, encodingRaw, connParamsRaw } =
parseWebSocketProtocols(protocols);
// Parse query
let queryUnvalidated: unknown;
try {
queryUnvalidated = JSON.parse(queryRaw!);
} catch (error) {
logger().error("invalid query json", { error });
throw new errors.InvalidQueryJSON(error);
}
// Parse conn params
let connParamsUnvalidated: unknown = null;
try {
if (connParamsRaw) {
connParamsUnvalidated = JSON.parse(connParamsRaw!);
}
} catch (error) {
logger().error("invalid conn params", { error });
throw new errors.InvalidParams(
`Invalid params JSON: ${stringifyError(error)}`,
);
}
// We can't use the standard headers with WebSockets
//
// All other information will be sent over the socket itself, since that data needs to be E2EE
const params = ConnectWebSocketRequestSchema.safeParse({
query: queryUnvalidated,
encoding: encodingRaw,
connParams: connParamsUnvalidated,
});
if (!params.success) {
logger().error("invalid connection parameters", {
error: params.error,
});
throw new errors.InvalidRequest(params.error);
}
encoding = params.data.encoding;
// Authenticate endpoint
const authData = await authenticateEndpoint(
c,
driver,
registryConfig,
params.data.query,
["connect"],
connParamsRaw,
);
// Get the actor ID
const { actorId } = await queryActor(c, params.data.query, driver);
logger().debug("found actor for websocket connection", {
actorId,
});
invariant(actorId, "missing actor id");
// Proxy the WebSocket connection to the actor
//
// The proxyWebSocket handler will:
// 1. Validate the WebSocket upgrade request
// 2. Forward the request to the actor with the appropriate path
// 3. Handle the WebSocket pair and proxy messages between client and actor
return await driver.proxyWebSocket(
c,
PATH_CONNECT_WEBSOCKET,
actorId,
params.data.encoding,
params.data.connParams,
authData,
);
} catch (error) {
// If we receive an error during setup, we send the error and close the socket immediately
//
// We have to return the error over WS since WebSocket clients cannot read vanilla HTTP responses
const { code, message, metadata } = deconstructError(error, logger(), {
wsEvent: "setup",
});
return await upgradeWebSocket(() => ({
onOpen: (_evt: unknown, ws: WSContext) => {
if (encoding) {
try {
// Serialize and send the connection error
const errorMsg: ToClient = {
b: {
e: {
c: code,
m: message,
md: metadata,
},
},
};
// Send the error message to the client
const serialized = serialize(errorMsg, encoding);
ws.send(serialized);
// Close the connection with an error code
ws.close(1011, code);
} catch (serializeError) {
logger().error("failed to send error to websocket client", {
error: serializeError,
});
ws.close(1011, "internal error during error handling");
}
} else {
// We don't know the encoding so we send what we can
ws.close(1011, code);
}
},
}))(c, noopNext());
}
}
/**
* Handle a connection message request to a actor
*
* There is no authentication handler on this request since the connection
* token is used to authenticate the message.
*/
async function handleMessageRequest(
c: HonoContext,
registryConfig: RegistryConfig,
runConfig: RunConfig,
driver: ManagerDriver,
): Promise<Response> {
logger().debug("connection message request received");
try {
const params = ConnMessageRequestSchema.safeParse({
actorId: c.req.header(HEADER_ACTOR_ID),
connId: c.req.header(HEADER_CONN_ID),
encoding: c.req.header(HEADER_ENCODING),
connToken: c.req.header(HEADER_CONN_TOKEN),
});
if (!params.success) {
logger().error("invalid connection parameters", {
error: params.error,
});
throw new errors.InvalidRequest(params.error);
}
const { actorId, connId, encoding, connToken } = params.data;
// TODO: This endpoint can be used to exhause resources (DoS attack) on an actor if you know the actor ID:
// 1. Get the actor ID (usually this is reasonably secure, but we don't assume actor ID is sensitive)
// 2. Spam messages to the actor (the conn token can be invalid)
// 3. The actor will be exhausted processing messages — even if the token is invalid
//
// The solution is we need to move the authorization of the connection token to this request handler
// AND include the actor ID in the connection token so we can verify that it has permission to send
// a message to that actor. This would require changing the token to a JWT so we can include a secure
// payload, but this requires managing a private key & managing key rotations.
//
// All other solutions (e.g. include the actor name as a header or include the actor name in the actor ID)
// have exploits that allow the caller to send messages to arbitrary actors.
//
// Currently, we assume this is not a critical problem because requests will likely get rate
// limited before enough messages are passed to the actor to exhaust resources.
const url = new URL("http://actor/connections/message");
// Always build fresh request to prevent forwarding unwanted headers
const proxyRequest = new Request(url, {
method: "POST",
body: c.req.raw.body,
duplex: "half",
});
proxyRequest.headers.set(HEADER_ENCODING, encoding);
proxyRequest.headers.set(HEADER_CONN_ID, connId);
proxyRequest.headers.set(HEADER_CONN_TOKEN, connToken);
return await driver.proxyRequest(c, proxyRequest, actorId);
} catch (error) {
logger().error("error proxying connection message", { error });
// Use ProxyError if it's not already an ActorError
if (!errors.ActorError.isActorError(error)) {
throw new errors.ProxyError("connection message", error);
} else {
throw error;
}
}
}
/**
* Handle an action request to a actor
*/
async function handleActionRequest(
c: HonoContext,
registryConfig: RegistryConfig,
runConfig: RunConfig,
driver: ManagerDriver,
): Promise<Response> {
try {
const actionName = c.req.param("action");
logger().debug("action call received", { actionName });
const params = ConnectRequestSchema.safeParse({
query: getRequestQuery(c),
encoding: c.req.header(HEADER_ENCODING),
connParams: c.req.header(HEADER_CONN_PARAMS),
});
if (!params.success) {
logger().error("invalid connection parameters", {
error: params.error,
});
throw new errors.InvalidRequest(params.error);
}
// Parse connection parameters for authentication
const connParams = params.data.connParams
? JSON.parse(params.data.connParams)
: undefined;
// Authenticate the request
const authData = await authenticateEndpoint(
c,
driver,
registryConfig,
params.data.query,
["action"],
connParams,
);
// Get the actor ID
const { actorId } = await queryActor(c, params.data.query, driver);
logger().debug("found actor for action", { actorId });
invariant(actorId, "Missing actor ID");
const url = new URL(
`http://actor/action/${encodeURIComponent(actionName)}`,
);
// Always build fresh request to prevent forwarding unwanted headers
const proxyRequest = new Request(url, {
method: "POST",
body: c.req.raw.body,
});
proxyRequest.headers.set(HEADER_ENCODING, params.data.encoding);
if (params.data.connParams) {
proxyRequest.headers.set(HEADER_CONN_PARAMS, params.data.connParams);
}
if (authData) {
proxyRequest.headers.set(HEADER_AUTH_DATA, JSON.stringify(authData));
}
return await driver.proxyRequest(c, proxyRequest, actorId);
} catch (error) {
logger().error("error in action handler", { error: stringifyError(error) });
// Use ProxyError if it's not already an ActorError
if (!errors.ActorError.isActorError(error)) {
throw new errors.ProxyError("Action call", error);
} else {
throw error;
}
}
}
/**
* Handle the resolve request to get a actor ID from a query
*/
async function handleResolveRequest(
c: HonoContext,
registryConfig: RegistryConfig,
driver: ManagerDriver,
): Promise<Response> {
const encoding = getRequestEncoding(c.req);
logger().debug("resolve request encoding", { encoding });
const params = ResolveRequestSchema.safeParse({
query: getRequestQuery(c),
connParams: c.req.header(HEADER_CONN_PARAMS),
});
if (!params.success) {
logger().error("invalid connection parameters", {
error: params.error,
});
throw new errors.InvalidRequest(params.error);
}
// Parse connection parameters for authentication
const connParams = params.data.connParams
? JSON.parse(params.data.connParams)
: undefined;
const query = params.data.query;
// Authenticate the request
await authenticateEndpoint(c, driver, registryConfig, query, [], connParams);
// Get the actor ID
const { actorId } = await queryActor(c, query, driver);
logger().debug("resolved actor", { actorId });
invariant(actorId, "Missing actor ID");
// Format response according to protocol
const response: protoHttpResolve.ResolveResponse = {
i: actorId,
};
const serialized = serialize(response, encoding);
return c.body(serialized);
}
/**
* Handle raw HTTP requests to an actor
*/
async function handleRawHttpRequest(
c: HonoContext,
registryConfig: RegistryConfig,
runConfig: RunConfig,
driver: ManagerDriver,
): Promise<Response> {
try {
const subpath = c.req.path.split("/raw/http/")[1] || "";
logger().debug("raw http request received", { subpath });
// Get actor query from header (consistent with other endpoints)
const queryHeader = c.req.header(HEADER_ACTOR_QUERY);
if (!queryHeader) {
throw new errors.InvalidRequest("Missing actor query header");
}
const query: ActorQuery = JSON.parse(queryHeader);
// Parse connection parameters for authentication
const connParamsHeader = c.req.header(HEADER_CONN_PARAMS);
const connParams = connParamsHeader
? JSON.parse(connParamsHeader)
: undefined;
// Authenticate the request
const authData = await authenticateEndpoint(
c,
driver,
registryConfig,
query,
["action"],
connParams,
);
// Get the actor ID
const { actorId } = await queryActor(c, query, driver);
logger().debug("found actor for raw http", { actorId });
invariant(actorId, "Missing actor ID");
// Preserve the original URL's query parameters
const originalUrl = new URL(c.req.url);
const url = new URL(
`http://actor/raw/http/${subpath}${originalUrl.search}`,
);
// Forward the request to the actor
const proxyRequest = new Request(url, {
method: c.req.method,
headers: c.req.raw.headers,
body: c.req.raw.body,
});
logger().debug("rewriting http url", {
from: c.req.url,
to: proxyRequest.url,
});
// Forward conn params if provided
if (connParams) {
proxyRequest.headers.set(HEADER_CONN_PARAMS, JSON.stringify(connParams));
}
// Forward auth data to actor
if (authData) {
proxyRequest.headers.set(HEADER_AUTH_DATA, JSON.stringify(authData));
}
return await driver.proxyRequest(c, proxyRequest, actorId);
} catch (error) {
logger().error("error in raw http handler", {
error: stringifyError(error),
});
// Use ProxyError if it's not already an ActorError
if (!errors.ActorError.isActorError(error)) {
throw new errors.ProxyError("Raw HTTP request", error);
} else {
throw error;
}
}
}
/**
* Handle raw WebSocket requests to an actor
*/
async function handleRawWebSocketRequest(
c: HonoContext,
registryConfig: RegistryConfig,
runConfig: RunConfig,
driver: ManagerDriver,
): Promise<Response> {
const upgradeWebSocket = runConfig.getUpgradeWebSocket?.();
if (!upgradeWebSocket) {
return c.text("WebSockets are not enabled for this driver.", 400);
}
try {
const subpath = c.req.path.split("/raw/websocket/")[1] || "";
logger().debug("raw websocket request received", { subpath });
// Parse protocols from Sec-WebSocket-Protocol header
const protocols = c.req.header("sec-websocket-protocol");
const {
queryRaw: queryFromProtocol,
connParamsRaw: connParamsFromProtocol,
} = parseWebSocketProtocols(protocols);
if (!queryFromProtocol) {
throw new errors.InvalidRequest("Missing query in WebSocket protocol");
}
const query = JSON.parse(queryFromProtocol);
// Parse connection parameters from protocol
let connParams: unknown;
if (connParamsFromProtocol) {
connParams = JSON.parse(connParamsFromProtocol);
}
// Authenticate the request
const authData = await authenticateEndpoint(
c,
driver,
registryConfig,
query,
["action"],
connParams,
);
// Get the actor ID
const { actorId } = await queryActor(c, query, driver);
logger().debug("found actor for raw websocket", { actorId });
invariant(actorId, "Missing actor ID");
logger().debug("using custom proxy mode for raw websocket");
// Preserve the original URL's query parameters
const originalUrl = new URL(c.req.url);
const proxyPath = `${PATH_RAW_WEBSOCKET_PREFIX}${subpath}${originalUrl.search}`;
logger().debug("manager router proxyWebSocket", {
originalUrl: c.req.url,
subpath,
search: originalUrl.search,
proxyPath,
});
// For raw WebSocket, we need to use proxyWebSocket instead of proxyRequest
return await driver.proxyWebSocket(
c,
proxyPath,
actorId,
"json", // Default encoding for raw WebSocket
connParams,
authData,
);
} catch (error) {
// If we receive an error during setup, we send the error and close the socket immediately
//
// We have to return the error over WS since WebSocket clients cannot read vanilla HTTP responses
const { code } = deconstructError(error, logger(), {
wsEvent: "setup",
});
return await upgradeWebSocket(() => ({
onOpen: (_evt: unknown, ws: WSContext) => {
// Close with message so we can see the error on the client
ws.close(1011, code);
},
}))(c, noopNext());
}
}
function universalActorProxy({
registryConfig,
runConfig,
driver,
}: {
registryConfig: RegistryConfig;
runConfig: RunConfig;
driver: ManagerDriver;
}): MiddlewareHandler {
return async (c, next) => {
if (c.req.header("upgrade") === "websocket") {
return handleRawWebSocketRequest(c, registryConfig, runConfig, driver);
} else {
const queryHeader = c.req.header(HEADER_ACTOR_QUERY);
if (!queryHeader) {
throw new errors.InvalidRequest("Missing actor query header");
}
const query = ActorQuerySchema.parse(JSON.parse(queryHeader));
const { actorId } = await queryActor(c, query, driver);
const url = new URL(c.req.url);
url.hostname = "actor";
url.pathname = url.pathname
.replace(new RegExp(`^${runConfig.basePath}`, ""), "")
.replace(/^\/registry\/actors/, "")
.replace(/^\/actors/, ""); // Remove /registry prefix if present
const proxyRequest = new Request(url, {
method: c.req.method,
headers: c.req.raw.headers,
body: c.req.raw.body,
});
return await driver.proxyRequest(c, proxyRequest, actorId);
}
};
}