UNPKG

@ws-kit/zod

Version:

Zod validator adapter for WS-Kit with runtime schema validation and full TypeScript inference

293 lines 15.4 kB
// SPDX-FileCopyrightText: 2025-present Kriasoft
// SPDX-License-Identifier: MIT

import { getRouteIndex } from "@ws-kit/core";
import {
  getKind,
  getRouterPluginAPI,
  getSchemaOpts,
  typeOf,
} from "@ws-kit/core/internal";
import { definePlugin } from "@ws-kit/core/plugin";
import {
  withMessaging as coreWithMessaging,
  withRpc as coreWithRpc,
} from "@ws-kit/plugins";
import { getZodPayload, validatePayload } from "./internal.js";

/**
 * Helper to format Zod errors for better DX.
 *
 * When the error exposes `flatten()`, form-level errors and per-field errors
 * are joined as `"msg; field: msg; ..."`. If flattening yields no issues, or
 * the error has no `flatten()`, the error is JSON-serialized instead.
 *
 * @param {*} error - The error object from a failed `safeParse` / `validatePayload`.
 * @returns {string} Human-readable summary of the validation failure.
 * @internal
 */
function formatValidationError(error) {
  if (error.flatten) {
    const flat = error.flatten();
    const issues = [
      ...(flat.formErrors || []),
      ...Object.entries(flat.fieldErrors || {}).flatMap(([field, msgs]) =>
        (msgs || []).map((m) => `${field}: ${m}`),
      ),
    ];
    return issues.length > 0 ? issues.join("; ") : JSON.stringify(error);
  }
  return JSON.stringify(error);
}

/**
 * Helper to resolve effective options, preferring per-schema over plugin defaults.
 *
 * @param {object | undefined} schemaOpts - Options attached to a schema (may be undefined).
 * @param {{ validateOutgoing?: boolean }} pluginOpts - Plugin-level defaults.
 * @returns {{ validateOutgoing: boolean }} `validateOutgoing` falls back to `true`
 *   when neither the schema nor the plugin sets it.
 * @internal
 */
function resolveOptions(schemaOpts, pluginOpts) {
  return {
    validateOutgoing:
      schemaOpts?.validateOutgoing ?? pluginOpts.validateOutgoing ?? true,
  };
}

/**
 * Build an Error tagged with a validation `code` and the raw validator error.
 *
 * @param {string} message - Error message.
 * @param {string} code - Machine-readable code (e.g. "VALIDATION_ERROR").
 * @param {*} details - The underlying validator error, exposed as `error.details`.
 * @returns {Error}
 * @internal
 */
function createValidationError(message, code, details) {
  const error = new Error(message);
  error.code = code;
  error.details = details;
  return error;
}

/**
 * Zod validation plugin for WS-Kit.
 *
 * Installs the core messaging and RPC plugins and validates every inbound
 * (non-`$ws:`) message against its registered root schema before the handler
 * runs. It also wraps `send()`, `reply()`, `progress()` and `publish()` so that
 * outgoing payloads are validated as well.
 *
 * @param {object} [options]
 * @param {boolean} [options.validateOutgoing=true] - Default for outbound
 *   validation. Per-schema options take precedence. For `reply()`/`progress()`,
 *   a per-call `opts.validate` takes precedence over both.
 * @param {Function} [options.onValidationError] - Called as
 *   `(error, { type, direction, payload })` on validation failure. When omitted,
 *   failures go to the router lifecycle's `handleError`.
 * @returns The plugin created by `definePlugin`.
 */
export function withZod(options) {
  const pluginOpts = {
    validateOutgoing: options?.validateOutgoing ?? true,
    onValidationError: options?.onValidationError,
  };

  // Send a validation failure to the user's hook when configured, otherwise
  // to the lifecycle error handler for this context.
  const reportValidationError = async (error, info, ctx, lifecycle) => {
    if (pluginOpts.onValidationError) {
      await pluginOpts.onValidationError(error, info);
    } else {
      await lifecycle.handleError(error, ctx);
    }
  };

  return definePlugin((router) => {
    // Step 1: Apply core messaging and RPC plugins first.
    // These provide ctx.send(), ctx.reply(), ctx.error(), ctx.progress().
    router.plugin(coreWithMessaging());
    router.plugin(coreWithRpc());

    // Step 2: Get plugin API for registering the context enhancer below.
    const api = getRouterPluginAPI(router);

    // Step 3: Middleware that validates the root message before calling next(),
    // so downstream handlers observe the validated ctx.payload.
    // NOTE(review): ordering relative to context enhancers is decided by core;
    // confirm against @ws-kit/core if that matters.
    router.use(async (ctx, next) => {
      // System lifecycle events ($ws:open, $ws:close) carry no user payload.
      if (typeof ctx.type === "string" && ctx.type.startsWith("$ws:")) {
        await next();
        return;
      }

      // Capture lifecycle for use in error handlers.
      const lifecycle = api.getLifecycle();

      const schemaInfo = getRouteIndex(router).get(ctx.type);
      const schema = schemaInfo?.schema;

      // Only Zod-like schemas (exposing safeParse) are validated inbound.
      if (typeof schema?.safeParse === "function") {
        // Normalized inbound message. Validating the root schema enforces
        // type, meta and payload together. Coercion is controlled by schema
        // design (z.coerce.*), not by runtime flags.
        const inboundMessage = {
          type: ctx.type,
          meta: ctx.meta || {},
          ...(ctx.payload !== undefined ? { payload: ctx.payload } : {}),
        };
        const result = schema.safeParse(inboundMessage);

        if (!result.success) {
          const validationError = createValidationError(
            `Validation failed for ${ctx.type}: ${formatValidationError(result.error)}`,
            "VALIDATION_ERROR",
            result.error,
          );
          await reportValidationError(
            validationError,
            { type: ctx.type, direction: "inbound", payload: ctx.payload },
            ctx,
            lifecycle,
          );
          // Invalid messages never reach the handler.
          return;
        }

        // Replace the raw payload with the parsed (possibly transformed) one.
        if (result.data.payload !== undefined) {
          ctx.payload = result.data.payload;
        }

        // Stash schema info for reply/progress validation. The property is
        // non-enumerable so it stays out of spreads and serialization.
        const kind = getKind(schema); // read from DESCRIPTOR symbol
        const existingWskit = ctx.__wskit || {};
        Object.defineProperty(ctx, "__wskit", {
          enumerable: false,
          configurable: true,
          value: {
            ...existingWskit,
            ...(kind !== undefined && { kind }),
            request: schema,
            response: schema.response,
          },
        });
      }

      await next();
    });

    // Step 4: Context enhancer that wraps the core messaging/RPC/pubsub
    // extension methods with optional outbound validation.
    api.addContextEnhancer(
      (ctx) => {
        // Capture lifecycle for use in nested functions.
        const lifecycle = api.getLifecycle();

        // Decide whether outbound validation runs. An explicit per-call
        // override wins; otherwise per-schema options, then plugin defaults.
        // Previously the helpers re-checked the defaults after the wrapper had
        // already applied the override, so `validate: true` was ignored when
        // validateOutgoing was false.
        const shouldValidateOutgoing = (schema, override) => {
          if (override != null) {
            return override;
          }
          const schemaOpts =
            typeof schema === "object" ? getSchemaOpts(schema) : undefined;
          return resolveOptions(schemaOpts, pluginOpts).validateOutgoing;
        };

        // Build, report and return an outbound validation error (caller throws).
        const outboundFailure = async (typeName, payload, validatorError) => {
          const label = typeName || "unknown";
          const validationError = createValidationError(
            `Outbound validation failed for ${label}: ${formatValidationError(validatorError)}`,
            "OUTBOUND_VALIDATION_ERROR",
            validatorError,
          );
          await reportValidationError(
            validationError,
            { type: label, direction: "outbound", payload },
            ctx,
            lifecycle,
          );
          return validationError;
        };

        // Validate an outgoing message. Returns the parsed payload (or the
        // original when parsing yields none). Throws after reporting on failure.
        const validateOutgoingPayload = async (schema, payload, override) => {
          if (!shouldValidateOutgoing(schema, override)) {
            return payload;
          }

          // Zod root schema: validate the full outbound message.
          if (typeof schema?.safeParse === "function") {
            const typeName = typeOf(schema, schema);
            const result = schema.safeParse({
              type: typeName,
              meta: {},
              ...(payload !== undefined ? { payload } : {}),
            });
            if (!result.success) {
              throw await outboundFailure(typeName, payload, result.error);
            }
            return result.data.payload ?? payload;
          }

          // Fallback for schemas without safeParse (e.g. manually built
          // descriptors). The message()/rpc() builders do not reach this branch.
          const payloadSchema = getZodPayload(schema);
          if (!payloadSchema) {
            // Non-Zod schema without payload metadata: skip validation.
            return payload;
          }
          const result = validatePayload(payload, payloadSchema);
          if (!result.success) {
            throw await outboundFailure(
              typeOf(schema, schema),
              payload,
              result.error,
            );
          }
          return result.data ?? payload;
        };

        // Validate a progress payload against the RPC response schema.
        const validateProgressPayload = async (
          responseSchema,
          progressPayload,
          override,
        ) => {
          if (!shouldValidateOutgoing(responseSchema, override)) {
            return progressPayload;
          }
          if (typeof responseSchema?.safeParse !== "function") {
            return progressPayload;
          }

          // Temporary message used only to check the payload shape.
          const result = responseSchema.safeParse({
            type: responseSchema.responseType || typeOf(responseSchema),
            meta: {},
            ...(progressPayload !== undefined
              ? { payload: progressPayload }
              : {}),
          });
          if (!result.success) {
            const validationError = createValidationError(
              `Progress validation failed for ${ctx.type}: ${formatValidationError(result.error)}`,
              "PROGRESS_VALIDATION_ERROR",
              result.error,
            );
            await reportValidationError(
              validationError,
              {
                type: "$ws:rpc-progress",
                direction: "outbound",
                payload: progressPayload,
              },
              ctx,
              lifecycle,
            );
            throw validationError;
          }
          return result.data.payload ?? progressPayload;
        };

        // Wrap extension methods in place. Core plugins expose delegates on
        // ctx that call through to these extensions, so no ctx assignment is
        // needed (avoids "enhancer overwrote ctx properties" warnings).
        const messagingExt = ctx.extensions.get("messaging");
        const rpcExt = ctx.extensions.get("rpc");
        const pubsubExt = ctx.extensions.get("pubsub");

        if (messagingExt?.send) {
          const coreSend = messagingExt.send;
          messagingExt.send = async (schema, payload, opts) => {
            const validatedPayload = await validateOutgoingPayload(
              schema,
              payload,
            );
            return coreSend(schema, validatedPayload, opts);
          };
        }

        if (rpcExt?.reply) {
          const coreReply = rpcExt.reply;
          rpcExt.reply = async (payload, opts) => {
            const response = ctx.__wskit?.response;
            if (!response) {
              return coreReply(payload, opts);
            }
            const validatedPayload = await validateOutgoingPayload(
              response,
              payload,
              opts?.validate,
            );
            return coreReply(validatedPayload, opts);
          };
        }

        if (rpcExt?.progress) {
          const coreProgress = rpcExt.progress;
          rpcExt.progress = async (payload, opts) => {
            const response = ctx.__wskit?.response;
            if (!response) {
              return coreProgress(payload, opts);
            }
            const validatedPayload = await validateProgressPayload(
              response,
              payload,
              opts?.validate,
            );
            return coreProgress(validatedPayload, opts);
          };
        }

        if (pubsubExt?.publish) {
          const corePublish = pubsubExt.publish;
          pubsubExt.publish = async (topic, schema, payload, opts) => {
            const validatedPayload = await validateOutgoingPayload(
              schema,
              payload,
            );
            return corePublish(topic, schema, validatedPayload, opts);
          };
        }
      },
      // NOTE(review): presumably chosen so core messaging/RPC/pubsub
      // extensions are registered before this wraps them; confirm in core.
      { priority: 100 },
    );

    // Plugin API extensions: capability marker for validation support.
    return {
      validation: true,
      __caps: { validation: true },
    };
  });
}
//# sourceMappingURL=plugin.js.map