@ws-kit/zod
Version:
Zod validator adapter for WS-Kit with runtime schema validation and full TypeScript inference
293 lines • 15.4 kB
JavaScript
// SPDX-FileCopyrightText: 2025-present Kriasoft
// SPDX-License-Identifier: MIT
import { getRouteIndex } from "@ws-kit/core";
import { getKind, getRouterPluginAPI, getSchemaOpts, typeOf, } from "@ws-kit/core/internal";
import { definePlugin } from "@ws-kit/core/plugin";
import { withMessaging as coreWithMessaging, withRpc as coreWithRpc, } from "@ws-kit/plugins";
import { getZodPayload, validatePayload } from "./internal.js";
/**
* Helper to format Zod errors for better DX.
* @internal
*/
/**
 * Format a validation error as a single human-readable line for error messages.
 *
 * Prefers `error.issues` (the issue list carried by a ZodError) so nested
 * paths such as `payload.user.name` are preserved. `flatten()` groups issues
 * only by their first path segment; because the plugin validates the whole
 * root message `{ type, meta, payload }`, every payload error would otherwise
 * be reported under a bare `payload` key.
 *
 * Falls back to `flatten()` for error objects without an issue list, and
 * finally to a JSON (or `String()`) rendering that never throws.
 *
 * @param {unknown} error - Usually the `error` from a failed `safeParse`.
 * @returns {string} Issues joined with "; ", or a serialized fallback.
 * @internal
 */
function formatValidationError(error) {
  const issues = Array.isArray(error?.issues) ? error.issues : undefined;
  if (issues && issues.length > 0) {
    return issues
      .map((issue) => {
        // Path segments may be numbers (array indices) or symbols;
        // String() is needed because join() throws on symbols.
        const path = Array.isArray(issue?.path)
          ? issue.path.map(String).join(".")
          : "";
        const message = issue?.message ?? "Invalid value";
        return path ? `${path}: ${message}` : message;
      })
      .join("; ");
  }
  if (typeof error?.flatten === "function") {
    const flat = error.flatten();
    const messages = [
      ...(flat.formErrors || []),
      ...Object.entries(flat.fieldErrors || {}).flatMap(([field, msgs]) =>
        (msgs || []).map((m) => `${field}: ${m}`),
      ),
    ];
    if (messages.length > 0) {
      return messages.join("; ");
    }
  }
  // Last resort: JSON.stringify throws on cycles/BigInt and returns
  // undefined for undefined input, so fall back to String() in those cases.
  try {
    return JSON.stringify(error) ?? String(error);
  } catch {
    return String(error);
  }
}
/**
* Helper to resolve effective options, preferring per-schema over plugin defaults.
* @internal
*/
/**
 * Compute the effective validation options for a single schema.
 *
 * Precedence: a non-nullish per-schema `validateOutgoing` wins, then a
 * non-nullish plugin-level `validateOutgoing`; when neither is set,
 * outgoing validation defaults to enabled.
 *
 * @param {{ validateOutgoing?: boolean } | undefined} schemaOpts - Per-schema options (may be absent).
 * @param {{ validateOutgoing?: boolean }} pluginOpts - Plugin-wide defaults.
 * @returns {{ validateOutgoing: boolean }} Resolved options.
 * @internal
 */
function resolveOptions(schemaOpts, pluginOpts) {
  const fromSchema = schemaOpts?.validateOutgoing;
  if (fromSchema != null) {
    return { validateOutgoing: fromSchema };
  }
  const fromPlugin = pluginOpts.validateOutgoing;
  return { validateOutgoing: fromPlugin != null ? fromPlugin : true };
}
/**
 * Build an Error carrying a machine-readable `code` and the raw validator
 * error in `details` (the shape used for every validation failure below).
 * @internal
 */
function createValidationError(message, code, details) {
  const error = new Error(message);
  error.code = code;
  error.details = details;
  return error;
}

/**
 * Zod validation plugin for WS-Kit.
 *
 * Installs the core messaging and RPC plugins, then:
 * - validates each inbound message (except `$ws:*` system events) whose route
 *   schema has `safeParse`, against the full root message
 *   `{ type, meta, payload }`, and replaces `ctx.payload` with the parsed payload;
 * - wraps the `send`, `reply`, `progress` and `publish` extension methods so
 *   outgoing payloads are validated before being handed to the core methods.
 *
 * Outgoing validation precedence: per-call `opts.validate` (reply/progress
 * only) > per-schema `validateOutgoing` > plugin `validateOutgoing` > `true`.
 *
 * On failure the error is passed to `options.onValidationError` when provided,
 * otherwise to the router's lifecycle error handler. Inbound failures stop the
 * middleware chain; outbound failures are also thrown to the caller.
 *
 * @param {object} [options]
 * @param {boolean} [options.validateOutgoing=true] - Plugin-wide default for outgoing validation.
 * @param {Function} [options.onValidationError] - Called as `(error, { type, direction, payload })`
 *   instead of the router error handler.
 * @returns The plugin produced by `definePlugin`.
 */
export function withZod(options) {
  const pluginOpts = {
    validateOutgoing: options?.validateOutgoing ?? true,
    onValidationError: options?.onValidationError,
  };
  return definePlugin((router) => {
    // Core plugins provide ctx.send(), ctx.reply(), ctx.error(), ctx.progress().
    // They are installed first so their extension methods exist when the
    // enhancer below wraps them.
    router.plugin(coreWithMessaging());
    router.plugin(coreWithRpc());
    const api = getRouterPluginAPI(router);

    /**
     * Route a validation failure to the user hook when configured,
     * otherwise to the router lifecycle error handler for `ctx`.
     */
    const reportValidationError = async (error, info, lifecycle, ctx) => {
      if (pluginOpts.onValidationError) {
        await pluginOpts.onValidationError(error, info);
      } else {
        await lifecycle.handleError(error, ctx);
      }
    };

    /**
     * Decide whether outgoing validation applies. A non-nullish `override`
     * (per-call `opts.validate`) wins and is interpreted by truthiness;
     * otherwise per-schema options, then plugin options, are consulted.
     */
    const isOutgoingValidationEnabled = (schema, override) => {
      if (override != null) {
        return Boolean(override);
      }
      const schemaOpts =
        schema !== null && typeof schema === "object"
          ? getSchemaOpts(schema)
          : undefined;
      return resolveOptions(schemaOpts, pluginOpts).validateOutgoing;
    };

    // Inbound validation middleware. It makes the validated ctx.payload
    // available to later middleware, handlers and the messaging methods.
    router.use(async (ctx, next) => {
      // System lifecycle events ($ws:open, $ws:close) are handled by
      // router.onOpen/onClose and are not payload-validated.
      if (typeof ctx.type === "string" && ctx.type.startsWith("$ws:")) {
        await next();
        return;
      }
      const lifecycle = api.getLifecycle();
      const schemaInfo = getRouteIndex(router).get(ctx.type);
      const schema = schemaInfo?.schema;
      // Only schemas exposing safeParse are validated; unknown types and
      // non-Zod schemas pass through unchanged.
      if (typeof schema?.safeParse === "function") {
        const enhCtx = ctx;
        const inboundMessage = {
          type: ctx.type,
          meta: enhCtx.meta || {},
          ...(enhCtx.payload !== undefined ? { payload: enhCtx.payload } : {}),
        };
        // Always safeParse so failures take a single error path. Coercion is
        // controlled by schema design (z.coerce.*), not runtime flags.
        const result = schema.safeParse(inboundMessage);
        if (!result.success) {
          const validationError = createValidationError(
            `Validation failed for ${ctx.type}: ${formatValidationError(result.error)}`,
            "VALIDATION_ERROR",
            result.error,
          );
          await reportValidationError(
            validationError,
            { type: ctx.type, direction: "inbound", payload: enhCtx.payload },
            lifecycle,
            ctx,
          );
          // Do not call next(): the handler must not see invalid input.
          return;
        }
        if (result.data.payload !== undefined) {
          enhCtx.payload = result.data.payload;
        }
        // Stash schema info (non-enumerable) for reply/progress validation.
        const kind = getKind(schema); // read from DESCRIPTOR symbol
        Object.defineProperty(enhCtx, "__wskit", {
          enumerable: false,
          configurable: true,
          value: {
            ...(enhCtx.__wskit || {}),
            ...(kind !== undefined && { kind }),
            request: schema,
            response: schema.response,
          },
        });
      }
      await next();
    });

    // Context enhancer: wraps the core messaging/RPC/pubsub methods with
    // optional outbound validation.
    api.addContextEnhancer((ctx) => {
      const enhCtx = ctx;
      const lifecycle = api.getLifecycle();

      /**
       * Validate an outgoing payload against `schema`, using full root-message
       * validation for Zod schemas and payload-only validation otherwise.
       * Returns the parsed payload. On failure the error is reported, then thrown.
       */
      const validateOutgoingPayload = async (schema, payload, override) => {
        if (!isOutgoingValidationEnabled(schema, override)) {
          return payload;
        }
        const schemaObj = schema;
        if (typeof schemaObj?.safeParse === "function") {
          const type = typeOf(schemaObj, schema);
          const result = schemaObj.safeParse({
            type,
            meta: {},
            ...(payload !== undefined ? { payload } : {}),
          });
          if (!result.success) {
            const label = type || "unknown";
            const validationError = createValidationError(
              `Outbound validation failed for ${label}: ${formatValidationError(result.error)}`,
              "OUTBOUND_VALIDATION_ERROR",
              result.error,
            );
            await reportValidationError(
              validationError,
              { type: label, direction: "outbound", payload },
              lifecycle,
              ctx,
            );
            throw validationError;
          }
          return result.data.payload ?? payload;
        }
        // Fallback for schemas without safeParse (e.g. manually constructed
        // descriptors); the message()/rpc() builders never reach this path.
        const payloadSchema = getZodPayload(schema);
        if (!payloadSchema) {
          // Without payload metadata, there is nothing to validate against.
          return payload;
        }
        const result = validatePayload(payload, payloadSchema);
        if (!result.success) {
          const label = typeOf(schemaObj, schema) || "unknown";
          const validationError = createValidationError(
            `Outbound validation failed for ${label}: ${formatValidationError(result.error)}`,
            "OUTBOUND_VALIDATION_ERROR",
            result.error,
          );
          await reportValidationError(
            validationError,
            { type: label, direction: "outbound", payload },
            lifecycle,
            ctx,
          );
          throw validationError;
        }
        return result.data ?? payload;
      };

      /**
       * Validate an RPC progress payload against the RPC response schema.
       * Returns the parsed payload. On failure the error is reported, then thrown.
       */
      const validateProgressPayload = async (responseSchema, progressPayload, override) => {
        if (!isOutgoingValidationEnabled(responseSchema, override)) {
          return progressPayload;
        }
        if (typeof responseSchema?.safeParse !== "function") {
          return progressPayload;
        }
        const result = responseSchema.safeParse({
          type: responseSchema.responseType || typeOf(responseSchema),
          meta: {},
          ...(progressPayload !== undefined ? { payload: progressPayload } : {}),
        });
        if (!result.success) {
          const validationError = createValidationError(
            `Progress validation failed for ${ctx.type}: ${formatValidationError(result.error)}`,
            "PROGRESS_VALIDATION_ERROR",
            result.error,
          );
          await reportValidationError(
            validationError,
            { type: "$ws:rpc-progress", direction: "outbound", payload: progressPayload },
            lifecycle,
            ctx,
          );
          throw validationError;
        }
        return result.data.payload ?? progressPayload;
      };

      // Core plugins expose delegates on ctx that call through to these
      // extension objects, so wrapping the extension methods is sufficient.
      // Not assigning onto ctx avoids "enhancer overwrote ctx properties" warnings.
      // NOTE(review): assumes extension objects are per-context. If they were
      // shared, these wrappers would stack on every enhancer run — confirm.
      const messagingExt = ctx.extensions.get("messaging");
      const rpcExt = ctx.extensions.get("rpc");
      const pubsubExt = ctx.extensions.get("pubsub");

      if (messagingExt?.send) {
        const coreSend = messagingExt.send;
        messagingExt.send = async (schema, payload, opts) =>
          coreSend(schema, await validateOutgoingPayload(schema, payload), opts);
      }
      if (rpcExt?.reply) {
        const coreReply = rpcExt.reply;
        rpcExt.reply = async (payload, opts) => {
          const response = enhCtx.__wskit?.response;
          if (!response) {
            return coreReply(payload, opts);
          }
          // opts.validate (when set) overrides schema/plugin defaults in
          // both directions, including forcing validation on.
          const validatedPayload = await validateOutgoingPayload(response, payload, opts?.validate);
          return coreReply(validatedPayload, opts);
        };
      }
      if (rpcExt?.progress) {
        const coreProgress = rpcExt.progress;
        rpcExt.progress = async (payload, opts) => {
          const response = enhCtx.__wskit?.response;
          if (!response) {
            return coreProgress(payload, opts);
          }
          const validatedPayload = await validateProgressPayload(response, payload, opts?.validate);
          return coreProgress(validatedPayload, opts);
        };
      }
      if (pubsubExt?.publish) {
        const corePublish = pubsubExt.publish;
        pubsubExt.publish = async (topic, schema, payload, opts) =>
          corePublish(topic, schema, await validateOutgoingPayload(schema, payload), opts);
      }
    }, { priority: 100 });

    // Plugin API extensions: a validation capability marker.
    return {
      validation: true,
      __caps: { validation: true },
    };
  });
}
//# sourceMappingURL=plugin.js.map