@xtr-dev/zod-rpc
Version:
Simple, type-safe RPC library with Zod validation and automatic TypeScript inference
284 lines • 10.5 kB
JavaScript
import { z } from 'zod';
import { RPCError, ValidationError, MethodNotFoundError, TimeoutError } from './errors.js';
import { implementService } from './service.js';
/**
 * Core communication channel that manages RPC method calls, message routing, and transport connections.
 * This is the central orchestrator for all RPC communication between clients and servers.
 *
 * @example
 * ```typescript
 * const channel = new Channel('client-123', 30000);
 *
 * // Connect to transport
 * const transport = createWebSocketTransport('ws://localhost:8080');
 * await channel.connect(transport);
 *
 * // Call remote method
 * const result = await channel.invoke(
 *   'server',
 *   'user.get',
 *   { userId: '123' },
 *   inputSchema,
 *   outputSchema
 * );
 * ```
 *
 * @group Core Classes
 */
export class Channel {
  channelId;
  defaultTimeout;
  methods = new Map(); // `${targetId}:${methodId}` -> published method definition
  pendingCalls = new Map(); // traceId -> { resolve, reject, timeout } for in-flight remote calls
  services = new Map(); // serviceId -> { id, methods } as recorded by publishServiceInfo()
  transports = new Set();
  connected = false;

  /**
   * @param channelId - Identifier sent as `callerId` on outgoing messages, used as the default
   *   target id by publishService(), and used as the prefix of generated trace ids.
   * @param defaultTimeout - Milliseconds invoke() waits for a remote reply when no explicit
   *   timeout is given.
   */
  constructor(channelId, defaultTimeout = 30000) {
    this.channelId = channelId;
    this.defaultTimeout = defaultTimeout;
    // A channel may exist without any transports; attach them later with connect().
  }

  /**
   * Attach and connect transports (sequentially, in argument order), route their incoming
   * messages to handleMessage(), mark the channel connected, and refresh the service info.
   *
   * @param transports - Transport objects exposing onMessage(), connect(), send() and disconnect().
   */
  async connect(...transports) {
    for (const transport of transports) {
      this.transports.add(transport);
      transport.onMessage(this.handleMessage.bind(this));
      await transport.connect();
    }
    this.connected = true;
    await this.publishServiceInfo();
  }

  /**
   * Reject all in-flight remote calls with a TimeoutError("Connection closed"), then disconnect
   * every transport and reset the connection state.
   *
   * Every transport's disconnect() is attempted even if an earlier one fails. The transport set is
   * cleared and `connected` reset in all cases. The first disconnect failure, if any, is rethrown
   * afterwards.
   */
  async disconnect() {
    for (const [traceId, pending] of this.pendingCalls) {
      clearTimeout(pending.timeout);
      pending.reject(new TimeoutError(`Connection closed`, traceId));
    }
    this.pendingCalls.clear();

    // Keep going past a failing transport so the others are not left connected but forgotten.
    const failures = [];
    for (const transport of this.transports) {
      try {
        await transport.disconnect();
      } catch (error) {
        failures.push(error);
      }
    }
    this.transports.clear();
    this.connected = false;

    if (failures.length > 0) {
      throw failures[0];
    }
  }

  /**
   * Register a single method definition so it can be called locally via invoke() or remotely
   * via incoming requests. A later definition with the same targetId and id replaces the earlier one.
   *
   * @param definition - Object with `targetId`, `id`, `input`/`output` schemas and `handler`.
   */
  publishMethod(definition) {
    const key = `${definition.targetId}:${definition.id}`;
    this.methods.set(key, definition);
  }

  /**
   * Publish a service implementation to this channel, making all its methods available for RPC calls.
   * This is a convenience method that combines implementService and publishMethod.
   *
   * @template T - Service methods record type
   * @param service - The service definition with schemas
   * @param implementation - The actual implementation functions
   * @param targetId - Optional target identifier (defaults to channelId)
   *
   * @example
   * ```typescript
   * const channel = new Channel('server');
   *
   * // Instead of:
   * // const methods = implementService(userService, implementation, 'server');
   * // methods.forEach(method => channel.publishMethod(method));
   *
   * // Just do:
   * channel.publishService(userService, {
   *   get: async ({ userId }) => ({ name: `User ${userId}`, email: `user${userId}@example.com` }),
   *   create: async ({ name, email }) => ({ id: '123', success: true })
   * });
   * ```
   */
  publishService(service, implementation, targetId) {
    const actualTargetId = targetId || this.channelId;
    const methods = implementService(service, implementation, actualTargetId);
    methods.forEach((method) => this.publishMethod(method));
  }

  /**
   * Call a method. A method published on this channel under `targetId:methodId` runs locally.
   * Otherwise a request is sent on every connected transport and the first matching reply settles
   * the call.
   *
   * @param targetId - Target the method is published under.
   * @param methodId - Method identifier.
   * @param input - Payload passed to the method.
   * @param inputSchema - Optional schema whose parse() is run on `input` before calling.
   * @param outputSchema - Optional schema whose parse() result is returned instead of the raw output.
   * @param timeout - Optional milliseconds to wait for a remote reply (falsy -> defaultTimeout).
   * @returns The (optionally validated) method output.
   * @throws ValidationError if input/output validation fails; Error('No transports connected') when
   *   no local method matches and no transport is attached; TimeoutError if no reply arrives in
   *   time; the transport's own error if sending fails; RPCError for remote error replies.
   */
  async invoke(targetId, methodId, input, inputSchema, outputSchema, timeout) {
    const traceId = this.generateTraceId();
    if (inputSchema) {
      try {
        inputSchema.parse(input);
      } catch (error) {
        throw new ValidationError(`Input validation failed: ${error}`, traceId);
      }
    }

    // Prefer a locally published method: no network round-trip needed.
    const localKey = `${targetId}:${methodId}`;
    const localMethod = this.methods.get(localKey);
    if (localMethod) {
      try {
        const validatedInput = localMethod.input.parse(input);
        const result = await localMethod.handler(validatedInput);
        const validatedOutput = localMethod.output.parse(result);
        if (outputSchema) {
          return outputSchema.parse(validatedOutput);
        }
        return validatedOutput;
      } catch (error) {
        if (error instanceof z.ZodError) {
          throw new ValidationError(`Local method validation failed: ${error.message}`, traceId);
        }
        throw error;
      }
    }

    // Not local - send via transport
    const message = {
      callerId: this.channelId,
      targetId,
      traceId,
      methodId,
      payload: input,
      type: 'request',
    };

    return new Promise((resolve, reject) => {
      // Check before registering the pending call so no timer or map entry is left behind.
      if (this.transports.size === 0) {
        reject(new Error('No transports connected'));
        return;
      }

      const timeoutMs = timeout || this.defaultTimeout;
      const timeoutHandle = setTimeout(() => {
        this.pendingCalls.delete(traceId);
        reject(new TimeoutError(`Method call timed out after ${timeoutMs}ms`, traceId));
      }, timeoutMs);

      this.pendingCalls.set(traceId, {
        resolve: (value) => {
          if (outputSchema) {
            try {
              const validated = outputSchema.parse(value);
              resolve(validated);
            } catch (error) {
              reject(new ValidationError(`Output validation failed: ${error}`, traceId));
            }
          } else {
            resolve(value);
          }
        },
        reject,
        timeout: timeoutHandle,
      });

      // A failed send means no reply is coming. Drop the pending entry and its timer, then reject.
      const failSend = (error) => {
        clearTimeout(timeoutHandle);
        this.pendingCalls.delete(traceId);
        reject(error);
      };

      let sendPromises;
      try {
        sendPromises = Array.from(this.transports, (transport) => transport.send(message));
      } catch (error) {
        // A transport whose send() throws synchronously would otherwise leak the pending call.
        failSend(error);
        return;
      }
      Promise.all(sendPromises).catch(failSend);
    });
  }

  /**
   * @param serviceId - Service identifier (publishServiceInfo() records this channel under its channelId).
   * @returns The recorded method summaries for the service, or an empty array if unknown.
   */
  getAvailableMethods(serviceId) {
    const service = this.services.get(serviceId);
    return service ? service.methods : [];
  }

  /** @returns All service info records currently stored in `services`. */
  getConnectedServices() {
    return Array.from(this.services.values());
  }

  /**
   * Dispatch an incoming transport message by its `type`. Unknown types are ignored.
   * Errors are logged rather than rethrown, because this runs as a transport callback.
   */
  async handleMessage(message) {
    try {
      switch (message.type) {
        case 'request':
          await this.handleRequest(message);
          break;
        case 'response':
          this.handleResponse(message);
          break;
        case 'error':
          this.handleError(message);
          break;
      }
    } catch (error) {
      console.error('Error handling message:', error);
    }
  }

  /**
   * Execute a published method for an incoming request and send the result back on all transports.
   *
   * Handler and validation failures are reported to the caller as error messages. Failures while
   * delivering the response are deliberately NOT reported as handler errors: they propagate to the
   * caller of this method (handleMessage logs them).
   */
  async handleRequest(message) {
    const key = `${message.targetId}:${message.methodId}`;
    const method = this.methods.get(key);
    if (!method) {
      await this.sendError(message, new MethodNotFoundError(message.methodId, message.traceId));
      return;
    }

    let validatedOutput;
    try {
      const validatedInput = method.input.parse(message.payload);
      const result = await method.handler(validatedInput);
      validatedOutput = method.output.parse(result);
    } catch (error) {
      let rpcError;
      if (error instanceof RPCError) {
        rpcError = error;
      } else if (error instanceof z.ZodError) {
        rpcError = new ValidationError(error.message, message.traceId);
      } else {
        rpcError = new RPCError('INTERNAL_ERROR', error instanceof Error ? error.message : 'Unknown error', message.traceId);
      }
      await this.sendError(message, rpcError);
      return;
    }

    const response = {
      callerId: this.channelId,
      targetId: message.callerId,
      traceId: message.traceId,
      methodId: message.methodId,
      payload: validatedOutput,
      type: 'response',
    };
    // Send response to all connected transports
    for (const transport of this.transports) {
      await transport.send(response);
    }
  }

  /** Settle the pending call matching the response's traceId; replies for unknown trace ids are ignored. */
  handleResponse(message) {
    const pending = this.pendingCalls.get(message.traceId);
    if (pending) {
      clearTimeout(pending.timeout);
      this.pendingCalls.delete(message.traceId);
      pending.resolve(message.payload);
    }
  }

  /** Reject the pending call matching the error's traceId with an RPCError built from its payload. */
  handleError(message) {
    const pending = this.pendingCalls.get(message.traceId);
    if (pending) {
      clearTimeout(pending.timeout);
      this.pendingCalls.delete(message.traceId);
      const error = new RPCError(message.payload?.code || 'UNKNOWN_ERROR', message.payload?.message || 'Unknown error occurred', message.traceId);
      pending.reject(error);
    }
  }

  /**
   * Send an error reply for `originalMessage` on all transports. Send failures are logged,
   * not rethrown.
   *
   * @param originalMessage - The request being answered.
   * @param error - Error object exposing toJSON(), used as the payload.
   */
  async sendError(originalMessage, error) {
    const errorMessage = {
      callerId: this.channelId,
      targetId: originalMessage.callerId,
      traceId: originalMessage.traceId,
      methodId: originalMessage.methodId,
      payload: error.toJSON(),
      type: 'error',
    };
    try {
      // Send error to all connected transports
      for (const transport of this.transports) {
        await transport.send(errorMessage);
      }
    } catch (sendError) {
      console.error('Failed to send error message:', sendError);
    }
  }

  /**
   * Record this channel's published methods under its own channelId in `services`. Each method
   * summary carries its id and the last dot-separated segment of the id as `name`.
   */
  async publishServiceInfo() {
    const methods = Array.from(this.methods.values()).map((method) => ({
      id: method.id,
      name: method.id.split('.').pop() || method.id,
    }));
    const serviceInfo = {
      id: this.channelId,
      methods,
    };
    this.services.set(this.channelId, serviceInfo);
  }

  /**
   * @returns `${channelId}-${Date.now()}-${random base-36 suffix of up to 9 chars}`.
   *   Not cryptographically random; used only to correlate requests with replies.
   */
  generateTraceId() {
    // slice(2, 11) is the non-deprecated equivalent of substr(2, 9): skip "0." and take 9 chars.
    return `${this.channelId}-${Date.now()}-${Math.random().toString(36).slice(2, 11)}`;
  }
}
//# sourceMappingURL=channel.js.map