UNPKG

trpc-uwebsockets

Version:

tRPC adapter for uWebSockets.js server

682 lines (611 loc) 17.8 kB
import {
  type AnyRouter,
  type CreateContextCallback,
  callTRPCProcedure,
  getErrorShape,
  getTRPCErrorFromUnknown,
  type inferRouterContext,
  isTrackedEnvelope,
  transformTRPCResponse,
  TRPCError,
} from '@trpc/server';
import type { NodeHTTPCreateContextFnOptions } from '@trpc/server/adapters/node-http';
import {
  type BaseHandlerOptions,
  type TRPCRequestInfo,
  parseConnectionParamsFromUnknown,
} from '@trpc/server/http';
import {
  isObservable,
  observableToAsyncIterable,
} from '@trpc/server/observable';
import {
  type TRPCClientOutgoingMessage,
  type TRPCConnectionParamsMessage,
  type TRPCReconnectNotification,
  type TRPCResponseMessage,
  type TRPCResultMessage,
  parseTRPCMessage,
} from '@trpc/server/rpc';
import {
  type MaybePromise,
  isAsyncIterable,
  isObject,
  iteratorResource,
  run,
  Unpromise,
} from '@trpc/server/unstable-core-do-not-import';
import { URL } from 'url';
import type {
  TemplatedApp,
  WebSocket,
  WebSocketBehavior,
} from 'uWebSockets.js';
import {
  createURL,
  decorateHttpResponse,
  type HttpResponseDecorated,
  uWsToRequestNoBody,
} from './fetchCompat';

/** Keeps only the non-function properties of `T`. */
type RemoveFunctions<T> = {
  // eslint-disable-next-line @typescript-eslint/no-unsafe-function-type
  [K in keyof T as NonNullable<T[K]> extends Function ? never : K]: T[K];
};

/**
 * A structure holding settings for a WebSocket handler.
 */
export type WebSocketBehaviorOptions = RemoveFunctions<WebSocketBehavior<any>>;

export type WebSocketConnection = WebSocket<WebsocketData>;

// following packages/server/src/adapters/ws.ts

/**
 * @public
 */
export type CreateWSSContextFnOptions = NodeHTTPCreateContextFnOptions<
  Request,
  HttpResponseDecorated // WebSocketConnection
> & {
  client: WebSocketConnection;
};

export type CreateWSSContextFn<TRouter extends AnyRouter> = (
  opts: CreateWSSContextFnOptions
) => MaybePromise<inferRouterContext<TRouter>>;

export type WSConnectionHandlerOptions<TRouter extends AnyRouter> =
  BaseHandlerOptions<TRouter, Request> &
    CreateContextCallback<
      inferRouterContext<TRouter>,
      CreateWSSContextFn<TRouter>
    >;

export type WebsocketsKeepAlive = {
  /**
   * Enable heartbeat messages
   * @default false
   */
  enabled: boolean;
  /**
   * Heartbeat interval in milliseconds
   * @default 30_000
   */
  pingMs?: number | undefined;
  /**
   * Terminate the WebSocket if no pong is received after this many milliseconds
   * @default 5_000
   */
  pongWaitMs?: number | undefined;
};

/**
 * WebSockets handler definition
 */
export type WebsocketsHandlerOptions<TRouter extends AnyRouter> =
  WSConnectionHandlerOptions<TRouter> & {
    /**
     * Url path prefix where the tRPC server will be registered.
     * @default ''
     */
    prefix?: string | undefined;
    /**
     * Specify if SSL is used. Set to true if you are using SSLApp or if the server is served behind SSL reverse proxy.
     * Controls whether the URL handed to `createContext` uses the `wss` or `ws` scheme.
     * @default false
     */
    ssl?: boolean | undefined;
    /**
     * Server-initiated heartbeat settings.
     */
    keepAlive?: WebsocketsKeepAlive | undefined;
    /**
     * Disable responding to ping messages from the client
     * **Not recommended** - this is mainly used for testing
     * @default false
     */
    dangerouslyDisablePong?: boolean | undefined;
    /**
     * uWebSockets.js WebSocket handler settings
     */
    uWsBehaviorOptions?: WebSocketBehaviorOptions | undefined;
  };

/**
 * An "inside-out" promise: `resolve`/`reject` are exposed so the promise can
 * be settled from outside its executor.
 */
interface Completer {
  promise: Promise<void>;
  resolve: () => void;
  reject: (error: Error) => void;
}

function createCompleter(): Completer {
  let resolve!: () => void;
  let reject!: (error: Error) => void;
  const promise = new Promise<void>((res, rej) => {
    resolve = res;
    reject = rej;
  });
  // Mark the promise as handled. Every `await completer.promise` still sees a
  // rejection, but when nobody is awaiting it at the moment it rejects (e.g.
  // the connectionParams message failed and was fully consumed) Node would
  // otherwise report an unhandled rejection, which by default kills the
  // process.
  promise.catch(() => {});
  return { promise, resolve, reject };
}

// data bound internally on each client
type WebsocketData = {
  req: Request;
  /** per-request abort controllers of the active subscriptions, keyed by request id */
  clientSubscriptions: Map<number | string, AbortController>;
  /** aborted by the `close` handler; its signal is passed to `createContext` as `info.signal` */
  abortController: AbortController;
  /** result of `createContext`; stays `undefined` when no `createContext` is configured */
  ctx: inferRouterContext<AnyRouter> | undefined;
  /**
   * inside-out promise that resolves when context is ready.
   *
   * - when this is null, the context resolution will be started
   * - otherwise all requests must await context resolution
   */
  ctxCompleter: Completer | null;
  keepAlive: KeepAliver | null;
  url: URL;
};

/**
 * Builds the uWebSockets.js behavior that serves tRPC over WebSockets.
 *
 * - `upgrade` attaches a fresh {@link WebsocketData} to every connection.
 * - `message` answers `PING`/`PONG`, lazily creates the context on the first
 *   message (consuming that message as connection params when the URL has
 *   `?connectionParams=1`), then parses and dispatches tRPC messages.
 * - `close` aborts all subscriptions of the connection.
 *
 * @param opts handler options (router, createContext, keepAlive, uWS settings, ...)
 * @param allClients set that connections are added to on `open` and removed
 *   from on `close`
 * @returns the behavior to register with `app.ws(...)`
 */
export function getWSConnectionHandler<TRouter extends AnyRouter>(
  opts: WebsocketsHandlerOptions<TRouter>,
  allClients: Set<WebSocketConnection>
): WebSocketBehavior<WebsocketData> {
  const { createContext, router } = opts;
  const { transformer } = router._def._config;

  /** Serializes and sends a response, ignoring failures on closed sockets. */
  function respond(
    client: WebSocketConnection,
    untransformedJSON: TRPCResponseMessage
  ) {
    try {
      client.send(
        JSON.stringify(
          transformTRPCResponse(router._def._config, untransformedJSON)
        )
      );
    } catch {
      // client.send can throw if connection is already closed.
      // happens when client forcefully terminates the connection
      // and server is sending keepalive messages
    }
  }

  /** Parses the first message as a TRPCConnectionParamsMessage; throws PARSE_ERROR if malformed. */
  function getConnectionParams(
    msgStr: string
  ): TRPCRequestInfo['connectionParams'] {
    let msg;
    try {
      msg = JSON.parse(msgStr) as TRPCConnectionParamsMessage;
      if (!isObject(msg)) {
        throw new Error('Message was not an object');
      }
    } catch (cause) {
      throw new TRPCError({
        code: 'PARSE_ERROR',
        message: `Malformed TRPCConnectionParamsMessage`,
        cause,
      });
    }
    const connectionParams = parseConnectionParamsFromUnknown(msg.data);
    return connectionParams;
  }

  /**
   * Executes one parsed tRPC message. Procedure errors are reported to the
   * client; only a missing `id` (or an inaccessible socket) rejects.
   */
  async function handleRequest(
    client: WebSocket<WebsocketData>,
    msg: TRPCClientOutgoingMessage
  ) {
    const {
      clientSubscriptions,
      ctx,
      req,
      abortController: connectionAbortController,
    } = client.getUserData();
    const { id, jsonrpc } = msg;
    /* istanbul ignore next -- @preserve */
    if (id === null) {
      throw new TRPCError({
        code: 'BAD_REQUEST',
        message: '`id` is required',
      });
    }
    if (msg.method === 'subscription.stop') {
      clientSubscriptions.get(id)?.abort();
      return;
    }
    const { path, lastEventId } = msg.params;
    let { input } = msg.params;
    const type = msg.method;
    try {
      if (lastEventId !== undefined) {
        // forward the client's last seen event id to the procedure input
        if (isObject(input)) {
          input = {
            ...input,
            lastEventId: lastEventId,
          };
        } else {
          input ??= {
            lastEventId: lastEventId,
          };
        }
      }
      if (ctx === null) {
        throw new Error('assertion: context should never be null');
      }

      const abortController = new AbortController();
      const result = await callTRPCProcedure({
        router,
        path,
        getRawInput: async () => input,
        ctx,
        type,
        signal: abortController.signal,
      });

      const isIterableResult = isAsyncIterable(result) || isObservable(result);

      if (type !== 'subscription') {
        if (isIterableResult) {
          throw new TRPCError({
            code: 'UNSUPPORTED_MEDIA_TYPE',
            message: `Cannot return an async iterable or observable from a ${type} procedure with WebSockets`,
          });
        }
        // send the value as data if the method is not a subscription
        respond(client, {
          id,
          jsonrpc,
          result: {
            type: 'data',
            data: result,
          },
        });
        return;
      }

      if (!isIterableResult) {
        throw new TRPCError({
          message: `Subscription ${path} did not return an observable or a AsyncGenerator`,
          code: 'INTERNAL_SERVER_ERROR',
        });
      }

      if (connectionAbortController.signal.aborted) {
        // The client disconnected while the procedure was being called. The
        // `close` handler has already aborted and cleared all subscriptions,
        // so nothing would ever stop this one: don't start iterating it.
        abortController.abort();
        return;
      }

      /* istanbul ignore next -- @preserve */
      if (clientSubscriptions.has(id)) {
        // duplicate request ids for client
        throw new TRPCError({
          message: `Duplicate id ${id}`,
          code: 'BAD_REQUEST',
        });
      }

      const iterable = isObservable(result)
        ? observableToAsyncIterable(result, abortController.signal)
        : result;

      run(async () => {
        await using iterator = iteratorResource(iterable);

        const abortPromise = new Promise<'abort'>((resolve) => {
          abortController.signal.onabort = () => resolve('abort');
        });
        // We need those declarations outside the loop for garbage collection reasons. If they
        // were declared inside, they would not be freed until the next value is present.
        let next:
          | null
          | TRPCError
          | Awaited<
              typeof abortPromise | ReturnType<(typeof iterator)['next']>
            >;
        let result: null | TRPCResultMessage<unknown>['result'];

        while (true) {
          next = await Unpromise.race([
            iterator.next().catch(getTRPCErrorFromUnknown),
            abortPromise,
          ]);

          if (next === 'abort') {
            await iterator.return?.();
            break;
          }
          if (next instanceof Error) {
            const error = getTRPCErrorFromUnknown(next);
            opts.onError?.({ error, path, type, ctx, req, input });
            respond(client, {
              id,
              jsonrpc,
              error: getErrorShape({
                config: router._def._config,
                error,
                type,
                path,
                input,
                ctx,
              }),
            });
            break;
          }
          if (next.done) {
            break;
          }

          result = {
            type: 'data',
            data: next.value,
          };

          if (isTrackedEnvelope(next.value)) {
            const [id, data] = next.value;
            result.id = id;
            result.data = {
              id,
              data,
            };
          }

          respond(client, {
            id,
            jsonrpc,
            result,
          });

          // free up references for garbage collection
          next = null;
          result = null;
        }

        respond(client, {
          id,
          jsonrpc,
          result: {
            type: 'stopped',
          },
        });
        clientSubscriptions.delete(id);
      }).catch((cause) => {
        const error = getTRPCErrorFromUnknown(cause);
        opts.onError?.({ error, path, type, ctx, req, input });
        respond(client, {
          id,
          jsonrpc,
          error: getErrorShape({
            config: router._def._config,
            error,
            type,
            path,
            input,
            ctx,
          }),
        });
        abortController.abort();
        // drop the dead entry so the id can be reused; only remove it if it
        // still belongs to this subscription
        if (clientSubscriptions.get(id) === abortController) {
          clientSubscriptions.delete(id);
        }
      });
      clientSubscriptions.set(id, abortController);

      respond(client, {
        id,
        jsonrpc,
        result: {
          type: 'started',
        },
      });
    } catch (cause) /* istanbul ignore next -- @preserve */ {
      // procedure threw an error
      const error = getTRPCErrorFromUnknown(cause);
      opts.onError?.({ error, path, type, ctx, req, input });
      respond(client, {
        id,
        jsonrpc,
        error: getErrorShape({
          config: router._def._config,
          error,
          type,
          path,
          input,
          ctx,
        }),
      });
    }
  }

  return {
    sendPingsAutomatically: opts.uWsBehaviorOptions?.sendPingsAutomatically, // could this be enabled?
    closeOnBackpressureLimit: opts.uWsBehaviorOptions?.closeOnBackpressureLimit,
    compression: opts.uWsBehaviorOptions?.compression,
    maxBackpressure: opts.uWsBehaviorOptions?.maxBackpressure,
    maxPayloadLength: opts.uWsBehaviorOptions?.maxPayloadLength,
    maxLifetime: opts.uWsBehaviorOptions?.maxLifetime,
    idleTimeout: opts.uWsBehaviorOptions?.idleTimeout,

    upgrade(res, req, context) {
      const resDecorated = decorateHttpResponse(res);
      res.onAborted(() => {
        resDecorated.aborted = true;
      });

      const reqFetch = uWsToRequestNoBody(req, resDecorated);

      const secWebSocketKey = req.getHeader('sec-websocket-key');
      const secWebSocketProtocol = req.getHeader('sec-websocket-protocol');
      const secWebSocketExtensions = req.getHeader('sec-websocket-extensions');

      const clientSubscriptions = new Map<number | string, AbortController>();
      const abortController = new AbortController();

      const data: WebsocketData = {
        clientSubscriptions,
        abortController,
        req: reqFetch,
        ctx: undefined,
        ctxCompleter: null,
        keepAlive: null,
        // honour the documented `ssl` option for the scheme of the context URL
        url: createURL(req, opts.ssl ? 'wss' : 'ws'),
      };

      res.upgrade(
        data,
        secWebSocketKey,
        secWebSocketProtocol,
        secWebSocketExtensions,
        context
      );
    },

    async open(client) {
      allClients.add(client);

      if (opts.keepAlive?.enabled) {
        const { pingMs, pongWaitMs } = opts.keepAlive;
        const data = client.getUserData();
        data.keepAlive = handleKeepAlive(client, pingMs, pongWaitMs);
      }
    },

    async message(client, rawMsg) {
      const data = client.getUserData();
      if (data.keepAlive) {
        // any incoming message counts as a sign of life
        data.keepAlive.onMessage();
      }

      // decode synchronously, before any await
      const msgStr = Buffer.from(rawMsg).toString();

      if (msgStr === 'PONG') {
        return;
      }
      if (msgStr === 'PING') {
        if (!opts.dangerouslyDisablePong) {
          client.send('PONG');
        }
        return;
      }

      if (data.ctxCompleter == null) {
        // first message on this connection: start context resolution
        data.ctxCompleter = createCompleter();
        const useConnectionParams =
          data.url.searchParams.get('connectionParams') === '1';

        try {
          data.ctx = await createContext?.({
            req: data.req,
            res: client as unknown as HttpResponseDecorated,
            client: client,
            info: {
              connectionParams: useConnectionParams
                ? getConnectionParams(msgStr)
                : null,
              calls: [],
              isBatchCall: false,
              accept: null,
              type: 'unknown',
              signal: data.abortController.signal,
              url: data.url,
            },
          });
          data.ctxCompleter.resolve();
        } catch (cause) {
          const error = getTRPCErrorFromUnknown(cause);
          opts.onError?.({
            ctx: data.ctx,
            error: error,
            input: undefined,
            path: undefined,
            type: 'unknown',
            req: data.req,
          });
          respond(client, {
            id: null,
            error: getErrorShape({
              config: router._def._config,
              error,
              type: 'unknown',
              path: undefined,
              input: undefined,
              ctx: data.ctx,
            }),
          });
          data.ctxCompleter.reject(error);

          // close in next tick
          (globalThis.setImmediate ?? globalThis.setTimeout)(() => {
            try {
              client.end(1008);
            } catch {
              // like client.send in respond(), this can throw when the
              // client has already disconnected; there is nothing left to close
            }
          });
        }

        if (useConnectionParams) {
          // fully consume first message
          return;
        }
      }

      try {
        await data.ctxCompleter.promise;
      } catch {
        // stop execution of pending requests when context could not be resolved
        // single error message will be sent
        return;
      }

      if (data.abortController.signal.aborted) {
        // the socket closed while the context was being resolved; the close
        // handler already cleaned up, so don't start new requests
        return;
      }

      try {
        const msgJSON: unknown = JSON.parse(msgStr);
        const msgs: unknown[] = Array.isArray(msgJSON) ? msgJSON : [msgJSON];
        const promises = msgs
          .map((raw) => parseTRPCMessage(raw, transformer))
          .map((msg) => {
            return handleRequest(client, msg);
          });

        await Promise.all(promises);
      } catch (cause) {
        const error = new TRPCError({
          code: 'PARSE_ERROR',
          cause,
        });

        respond(client, {
          id: null,
          error: getErrorShape({
            config: router._def._config,
            error,
            type: 'unknown',
            path: undefined,
            input: undefined,
            ctx: undefined,
          }),
        });
      }
    },

    close(client) {
      const { clientSubscriptions, abortController, keepAlive } =
        client.getUserData();

      if (keepAlive) {
        keepAlive.onClose();
      }

      for (const sub of clientSubscriptions.values()) {
        sub.abort();
      }
      clientSubscriptions.clear();
      abortController.abort();
      allClients.delete(client);
    },
  };
}

/**
 * Registers the tRPC WebSocket behavior on `app` at `opts.prefix` (default `''`).
 *
 * @returns helpers; `broadcastReconnectNotification` sends a tRPC reconnect
 *   notification to every currently open client.
 */
export function applyWebsocketHandler<TRouter extends AnyRouter>(
  app: TemplatedApp,
  opts: WebsocketsHandlerOptions<TRouter>
) {
  const allClients = new Set<WebSocket<WebsocketData>>();
  const behavior = getWSConnectionHandler(opts, allClients);

  const prefix = opts.prefix ?? '';

  app.ws(prefix, behavior);

  return {
    broadcastReconnectNotification: () => {
      const response: TRPCReconnectNotification = {
        id: null,
        method: 'reconnect',
      };
      const data = JSON.stringify(response);
      for (const client of allClients) {
        client.send(data);
      }
    },
  };
}

type KeepAliver = {
  /** call on every incoming message to reset the heartbeat timers */
  onMessage: () => void;
  /** call when the socket closes to cancel pending timers */
  onClose: () => void;
};

/**
 * Server-side heartbeat: after `pingMs` without incoming messages a `PING` is
 * sent, and the socket is closed (uWS `close()`) if no message arrives within
 * `pongWaitMs` after that.
 */
export function handleKeepAlive(
  client: WebSocketConnection,
  pingMs = 30_000,
  pongWaitMs = 5_000
): KeepAliver {
  let timeout: NodeJS.Timeout | undefined = undefined;
  let ping: NodeJS.Timeout | undefined = undefined;

  const schedulePing = () => {
    const scheduleTimeout = () => {
      timeout = setTimeout(() => {
        client.close();
      }, pongWaitMs);
    };
    ping = setTimeout(() => {
      client.send('PING');

      scheduleTimeout();
    }, pingMs);
  };

  schedulePing();

  return {
    onMessage() {
      clearTimeout(ping);
      clearTimeout(timeout);

      schedulePing();
    },
    onClose() {
      clearTimeout(ping);
      clearTimeout(timeout);
    },
  };
}