UNPKG

@xtr-dev/zod-rpc

Version:

Simple, type-safe RPC library with Zod validation and automatic TypeScript inference

284 lines 10.5 kB
import { z } from 'zod';
import { RPCError, ValidationError, MethodNotFoundError, TimeoutError } from './errors.js';
import { implementService } from './service.js';

/**
 * Core communication channel that manages RPC method calls, message routing, and transport connections.
 * This is the central orchestrator for all RPC communication between clients and servers.
 *
 * @example
 * ```typescript
 * const channel = new Channel('client-123', 30000);
 *
 * // Connect to transport
 * const transport = createWebSocketTransport('ws://localhost:8080');
 * await channel.connect(transport);
 *
 * // Call remote method
 * const result = await channel.invoke(
 *   'server',
 *   'user.get',
 *   { userId: '123' },
 *   inputSchema,
 *   outputSchema
 * );
 * ```
 *
 * @group Core Classes
 */
export class Channel {
  channelId;
  defaultTimeout;
  methods = new Map(); // targetId:methodId -> method
  pendingCalls = new Map(); // traceId -> { resolve, reject, timeout }
  services = new Map(); // serviceId -> { id, methods }
  transports = new Set();
  connected = false;

  /**
   * @param channelId - Identifier of this channel; used as `callerId` on outgoing
   *   messages and as the default target id for published services.
   * @param defaultTimeout - Timeout in milliseconds applied to remote calls that
   *   do not pass their own timeout.
   */
  constructor(channelId, defaultTimeout = 30000) {
    this.channelId = channelId;
    this.defaultTimeout = defaultTimeout;
    // Channel can now exist without any transports initially
  }

  /**
   * Register the given transports with this channel and connect them one after
   * another, then record this channel's own service info.
   *
   * @param transports - Objects exposing `onMessage(handler)`, `connect()`,
   *   `disconnect()` and `send(message)`.
   */
  async connect(...transports) {
    // Add transports to the set and connect them
    for (const transport of transports) {
      this.transports.add(transport);
      transport.onMessage(this.handleMessage.bind(this));
      await transport.connect();
    }
    this.connected = true;
    await this.publishServiceInfo();
  }

  /**
   * Reject every in-flight call, disconnect all transports and reset the
   * channel's connection state.
   *
   * Every transport is asked to disconnect even if an earlier one fails, and
   * the channel state is always reset, so a single broken transport cannot
   * leave the channel half-connected. The first disconnect failure is rethrown
   * once cleanup has finished; any further failures are logged.
   */
  async disconnect() {
    for (const [traceId, pending] of this.pendingCalls) {
      clearTimeout(pending.timeout);
      pending.reject(new TimeoutError(`Connection closed`, traceId));
    }
    this.pendingCalls.clear();

    // Disconnect all transports
    const failures = [];
    for (const transport of this.transports) {
      try {
        await transport.disconnect();
      } catch (error) {
        failures.push(error);
      }
    }
    this.transports.clear();
    this.connected = false;

    if (failures.length > 0) {
      for (const extra of failures.slice(1)) {
        console.error('Failed to disconnect transport:', extra);
      }
      throw failures[0];
    }
  }

  /**
   * Register a single method definition under the key `targetId:id`.
   * A later definition with the same key replaces the earlier one.
   *
   * @param definition - Method definition providing `targetId`, `id`, `input`
   *   and `output` schemas (objects with `parse`) and a `handler` function.
   */
  publishMethod(definition) {
    const key = `${definition.targetId}:${definition.id}`;
    this.methods.set(key, definition);
  }

  /**
   * Publish a service implementation to this channel, making all its methods available for RPC calls.
   * This is a convenience method that combines implementService and publishMethod.
   *
   * @template T - Service methods record type
   * @param service - The service definition with schemas
   * @param implementation - The actual implementation functions
   * @param targetId - Optional target identifier (defaults to channelId)
   *
   * @example
   * ```typescript
   * const channel = new Channel('server');
   *
   * // Instead of:
   * // const methods = implementService(userService, implementation, 'server');
   * // methods.forEach(method => channel.publishMethod(method));
   *
   * // Just do:
   * channel.publishService(userService, {
   *   get: async ({ userId }) => ({ name: `User ${userId}`, email: `user${userId}@example.com` }),
   *   create: async ({ name, email }) => ({ id: '123', success: true })
   * });
   * ```
   */
  publishService(service, implementation, targetId) {
    const actualTargetId = targetId || this.channelId;
    const methods = implementService(service, implementation, actualTargetId);
    methods.forEach((method) => this.publishMethod(method));
  }

  /**
   * Call a method, locally when it is published on this channel, otherwise
   * remotely over every connected transport.
   *
   * @param targetId - Id of the channel/service that owns the method.
   * @param methodId - Id of the method to call.
   * @param input - Call payload.
   * @param inputSchema - Optional schema; `input` is checked with `parse` before anything else.
   * @param outputSchema - Optional schema applied to the result before it is returned.
   * @param timeout - Optional timeout in milliseconds for remote calls; a falsy
   *   value falls back to `defaultTimeout`.
   * @returns The (validated) result of the call.
   * @throws {ValidationError} When input or output validation fails.
   * @throws {TimeoutError} When no response arrives in time, or the channel is
   *   disconnected while the call is pending.
   * @throws {RPCError} When the remote side answers with an error message.
   * @throws {Error} When the method is not local and no transports are connected,
   *   or when sending the request fails.
   */
  async invoke(targetId, methodId, input, inputSchema, outputSchema, timeout) {
    const traceId = this.generateTraceId();
    if (inputSchema) {
      try {
        inputSchema.parse(input);
      } catch (error) {
        throw new ValidationError(`Input validation failed: ${error}`, traceId);
      }
    }

    // Check for local method first
    const localKey = `${targetId}:${methodId}`;
    const localMethod = this.methods.get(localKey);
    if (localMethod) {
      // Execute locally - no network transport needed
      try {
        const validatedInput = localMethod.input.parse(input);
        const result = await localMethod.handler(validatedInput);
        const validatedOutput = localMethod.output.parse(result);
        if (outputSchema) {
          return outputSchema.parse(validatedOutput);
        }
        return validatedOutput;
      } catch (error) {
        if (error instanceof z.ZodError) {
          throw new ValidationError(`Local method validation failed: ${error.message}`, traceId);
        }
        throw error;
      }
    }

    // Not local - send via transport
    const message = {
      callerId: this.channelId,
      targetId,
      traceId,
      methodId,
      payload: input,
      type: 'request',
    };

    return new Promise((resolve, reject) => {
      // Fail fast BEFORE arming the timer or registering the call, so a call
      // that can never be sent leaves neither a live timer nor a pending entry.
      if (this.transports.size === 0) {
        reject(new Error('No transports connected'));
        return;
      }

      const timeoutMs = timeout || this.defaultTimeout;

      // Reject the call and release its timer and bookkeeping entry.
      const failCall = (error) => {
        const pending = this.pendingCalls.get(traceId);
        if (pending) {
          clearTimeout(pending.timeout);
          this.pendingCalls.delete(traceId);
        }
        reject(error);
      };

      const timeoutHandle = setTimeout(() => {
        this.pendingCalls.delete(traceId);
        reject(new TimeoutError(`Method call timed out after ${timeoutMs}ms`, traceId));
      }, timeoutMs);

      // Register the call before sending so that a response delivered
      // synchronously by a transport is still matched.
      this.pendingCalls.set(traceId, {
        resolve: (value) => {
          if (outputSchema) {
            try {
              const validated = outputSchema.parse(value);
              resolve(validated);
            } catch (error) {
              reject(new ValidationError(`Output validation failed: ${error}`, traceId));
            }
          } else {
            resolve(value);
          }
        },
        reject,
        timeout: timeoutHandle,
      });

      // Send message to all connected transports. The async wrapper turns a
      // synchronous throw from `send` into a rejection, so every send failure
      // goes through `failCall` and cleans up the timer.
      const sendPromises = Array.from(this.transports, async (transport) => transport.send(message));
      Promise.all(sendPromises).catch(failCall);
    });
  }

  /**
   * List the methods recorded for a service.
   *
   * @param serviceId - Id of the service to look up.
   * @returns The service's method descriptors, or an empty array when the
   *   service is unknown.
   */
  getAvailableMethods(serviceId) {
    const service = this.services.get(serviceId);
    return service ? service.methods : [];
  }

  /**
   * @returns An array with the info object of every recorded service.
   */
  getConnectedServices() {
    return Array.from(this.services.values());
  }

  /**
   * Transport message entry point: dispatches on `message.type`
   * (`'request'`, `'response'` or `'error'`). Other types are ignored.
   * Errors raised while handling are logged and not rethrown.
   *
   * @param message - Incoming RPC message.
   */
  async handleMessage(message) {
    try {
      switch (message.type) {
        case 'request':
          await this.handleRequest(message);
          break;
        case 'response':
          this.handleResponse(message);
          break;
        case 'error':
          this.handleError(message);
          break;
      }
    } catch (error) {
      console.error('Error handling message:', error);
    }
  }

  /**
   * Execute a published method for an incoming request and send the result,
   * or an error message, back over every connected transport.
   *
   * @param message - Incoming message of type `'request'`.
   */
  async handleRequest(message) {
    const key = `${message.targetId}:${message.methodId}`;
    const method = this.methods.get(key);
    if (!method) {
      await this.sendError(message, new MethodNotFoundError(message.methodId, message.traceId));
      return;
    }

    try {
      const validatedInput = method.input.parse(message.payload);
      const result = await method.handler(validatedInput);
      const validatedOutput = method.output.parse(result);
      const response = {
        callerId: this.channelId,
        targetId: message.callerId,
        traceId: message.traceId,
        methodId: message.methodId,
        payload: validatedOutput,
        type: 'response',
      };
      // Send response to all connected transports.
      // NOTE: a failing `send` here also lands in the catch below and is
      // reported to the caller as INTERNAL_ERROR.
      for (const transport of this.transports) {
        await transport.send(response);
      }
    } catch (error) {
      let rpcError;
      if (error instanceof RPCError) {
        rpcError = error;
      } else if (error instanceof z.ZodError) {
        rpcError = new ValidationError(error.message, message.traceId);
      } else {
        rpcError = new RPCError(
          'INTERNAL_ERROR',
          error instanceof Error ? error.message : 'Unknown error',
          message.traceId,
        );
      }
      await this.sendError(message, rpcError);
    }
  }

  /**
   * Settle the pending call matching `message.traceId` with the response
   * payload. Responses without a matching pending call are ignored.
   *
   * @param message - Incoming message of type `'response'`.
   */
  handleResponse(message) {
    const pending = this.pendingCalls.get(message.traceId);
    if (pending) {
      clearTimeout(pending.timeout);
      this.pendingCalls.delete(message.traceId);
      pending.resolve(message.payload);
    }
  }

  /**
   * Reject the pending call matching `message.traceId` with an `RPCError`
   * built from the payload's `code` and `message` (with fallbacks when they
   * are missing). Errors without a matching pending call are ignored.
   *
   * @param message - Incoming message of type `'error'`.
   */
  handleError(message) {
    const pending = this.pendingCalls.get(message.traceId);
    if (pending) {
      clearTimeout(pending.timeout);
      this.pendingCalls.delete(message.traceId);
      const error = new RPCError(
        message.payload?.code || 'UNKNOWN_ERROR',
        message.payload?.message || 'Unknown error occurred',
        message.traceId,
      );
      pending.reject(error);
    }
  }

  /**
   * Send an error reply for `originalMessage` over every connected transport.
   * Send failures are logged rather than thrown.
   *
   * @param originalMessage - The request being answered.
   * @param error - Error object exposing `toJSON()`; its result becomes the payload.
   */
  async sendError(originalMessage, error) {
    const errorMessage = {
      callerId: this.channelId,
      targetId: originalMessage.callerId,
      traceId: originalMessage.traceId,
      methodId: originalMessage.methodId,
      payload: error.toJSON(),
      type: 'error',
    };
    try {
      // Send error to all connected transports
      for (const transport of this.transports) {
        await transport.send(errorMessage);
      }
    } catch (sendError) {
      console.error('Failed to send error message:', sendError);
    }
  }

  /**
   * Record this channel's own service info (its id plus the id and short name
   * of every currently published method) in `services`.
   */
  async publishServiceInfo() {
    const methods = Array.from(this.methods.values()).map((method) => ({
      id: method.id,
      // Short name is the last dot-separated segment of the id.
      name: method.id.split('.').pop() || method.id,
    }));
    const serviceInfo = {
      id: this.channelId,
      methods,
    };
    this.services.set(this.channelId, serviceInfo);
  }

  /**
   * Build a trace id from the channel id, the current timestamp and a random
   * base-36 suffix of up to nine characters.
   *
   * @returns The trace id string.
   */
  generateTraceId() {
    // slice(2, 11) drops the leading "0." and keeps up to nine characters.
    return `${this.channelId}-${Date.now()}-${Math.random().toString(36).slice(2, 11)}`;
  }
}
//# sourceMappingURL=channel.js.map