UNPKG

@cloudpss/ubrpc

Version:

RPC server/client built on WebSocket and UBJSON.

346 lines (341 loc) 13.5 kB
import { connected } from '@cloudpss/fetch';
import { from, Observable, type Subscriber, type Subscription } from 'rxjs';
import type {
    ConnectionID,
    RpcCallPayload,
    RpcMetadata,
    RpcNotifyPayload,
    RpcPayload,
    RpcSubscribePayload,
} from './types/payload.js';
import type {
    Methods,
    Notifications,
    ObservableLike,
    RpcObject,
    RpcParameters,
    RpcReturns,
    Subjects,
} from './types/utils.js';
import { send } from './utils/messaging.js';
import { decodePayload, deserializeError, serializeError } from './utils/serialize.js';
import { logger } from './logger.js';
import { WebSocketAppCode } from './codes.js';

/**
 * RPC connection over a WebSocket.
 *
 * Calls, notifies and subscribes to methods on the remote side (`TRemote`), and answers
 * incoming calls/subscriptions using an optional local object (`TLocal`).
 * Subclasses implement {@link RpcSocket.authSocket} and assign the protected `socket` setter.
 */
export abstract class RpcSocket<TRemote extends object, TLocal extends object> {
    constructor(
        readonly id: ConnectionID,
        local?: RpcObject<TLocal>,
    ) {
        this.__local = local;
        // Start with a pending `ready` promise; sends wait on it until a socket is initialized.
        this.makeReady();
    }

    protected _localMetadata?: RpcMetadata;
    protected _remoteMetadata?: RpcMetadata;

    /** Local authentication metadata. */
    get localMetadata(): RpcMetadata | undefined {
        return this._localMetadata;
    }

    /** Remote authentication metadata; `undefined` while not authenticated. */
    get remoteMetadata(): RpcMetadata | undefined {
        return this._remoteMetadata;
    }

    /** Whether the connection is authenticated (remote metadata is present). */
    get authenticated(): boolean {
        return this._remoteMetadata != null;
    }

    private readonly __local?: RpcObject<TLocal>;

    /** Local object used to answer incoming calls and subscriptions. */
    protected get local(): RpcObject<TLocal> | undefined {
        return this.__local;
    }

    private __socket?: WebSocket;

    /**
     * The WebSocket used as the underlying transport.
     * @throws Error if this RPC socket is destroyed or no socket has been assigned yet.
     */
    get socket(): WebSocket {
        if (this.#destroyed) throw new Error(`RPC Socket destroyed.`);
        if (!this.__socket) throw new Error(`Socket not initialized`);
        return this.__socket;
    }

    /**
     * Replaces the underlying WebSocket. `ready` becomes pending (or stays pending) and is
     * settled by the outcome of {@link RpcSocket.initSocket} for the new socket.
     */
    protected set socket(value: WebSocket) {
        if (this.#destroyed) throw new Error(`RPC Socket destroyed.`);
        if (this.__socket === value) return;
        const callbacks = this.makeReady();
        const oldValue = this.__socket;
        this.__socket = value;
        void this.initSocket(oldValue, value).then(...callbacks);
    }

    /**
     * Creates a new pending `ready` promise, or reuses the current one if it has not settled yet.
     * @returns the `[resolve, reject]` callbacks of the pending `ready` promise.
     */
    private makeReady(): Parameters<ConstructorParameters<typeof Promise<void>>[0]> {
        if (this.__readyCallbacks != null) return this.__readyCallbacks;
        const ready = new Promise<void>((...args) => (this.__readyCallbacks = args));
        ready
            .catch(() => {
                // Avoid unhandled rejection
            })
            .finally(() => {
                // Only forget the callbacks if no newer `ready` promise replaced this one.
                if (this.ready === ready) this.__readyCallbacks = undefined;
            });
        this.ready = ready;
        if (!this.__readyCallbacks) throw new Error(`Bad promise`);
        return this.__readyCallbacks;
    }

    private __readyCallbacks?: Parameters<ConstructorParameters<typeof Promise<void>>[0]>;

    /** Settles once the current socket is initialized; `sendPayload` waits on it. */
    protected ready!: Promise<void>;

    /** Stable listener references, so they can be removed again from a replaced socket. */
    private readonly __handlers = Object.freeze({
        open: (ev: Event) => this.onOpen(ev),
        close: (ev: CloseEvent) => this.onClose(ev),
        error: (ev: Event | ErrorEvent) => this.onError(ev),
        message: (ev: MessageEvent) => this.onMessage(ev),
    } as const);

    /**
     * Initializes a newly assigned WebSocket.
     *
     * Detaches the listeners from the previous socket and closes it with
     * `WebSocketAppCode.REPLACED` if it is still open or connecting. Then awaits
     * `connected(newValue)` and {@link RpcSocket.authSocket}. If `newValue` is still the
     * current socket, attaches the listeners and stores the remote metadata.
     * @throws the connection/authentication error, after clearing the remote metadata.
     */
    protected async initSocket(oldValue: WebSocket | undefined, newValue: WebSocket): Promise<void> {
        try {
            if (oldValue) {
                oldValue.removeEventListener('open', this.__handlers.open);
                oldValue.removeEventListener('close', this.__handlers.close);
                oldValue.removeEventListener('error', this.__handlers.error);
                oldValue.removeEventListener('message', this.__handlers.message);
                if (oldValue.readyState === oldValue.OPEN || oldValue.readyState === oldValue.CONNECTING) {
                    logger('[%s] close old socket', this.id);
                    oldValue.close(WebSocketAppCode.REPLACED);
                }
            }
            await connected(newValue);
            const info = await this.authSocket();
            // The socket may have been replaced again while we were connecting/authenticating.
            if (this.__socket === newValue) {
                newValue.addEventListener('open', this.__handlers.open);
                newValue.addEventListener('close', this.__handlers.close);
                newValue.addEventListener('error', this.__handlers.error);
                newValue.addEventListener('message', this.__handlers.message);
                this._remoteMetadata = info ?? {};
            }
        } catch (ex) {
            logger('[%s] connection initialize failed. error=', this.id, ex);
            this._remoteMetadata = undefined;
            throw ex;
        }
    }

    /** Authenticates the current WebSocket and resolves with the remote metadata. */
    protected abstract authSocket(): Promise<RpcMetadata>;

    /** Handles a WebSocket `error` event: drops authentication and makes `ready` pending again. */
    protected onError(_ev: ErrorEvent | Event): void {
        this._remoteMetadata = undefined;
        this.makeReady();
    }

    /** Handles a WebSocket `open` event (no-op by default). */
    protected onOpen(_ev: Event): void {
        //
    }

    /**
     * Handles a WebSocket `close` event. A normal closure (code 1000) destroys this RPC socket;
     * any other code drops authentication and makes `ready` pending until a new socket is initialized.
     */
    protected onClose(ev: CloseEvent): void {
        this._remoteMetadata = undefined;
        if (ev.code === 1000) {
            this.destroy();
        } else {
            this.makeReady();
        }
    }

    /**
     * Handles a WebSocket `message` event. Decodes the payload and dispatches it to
     * {@link RpcSocket.onPayload}. Undecodable or unhandled messages are answered with an
     * `error` payload.
     */
    protected onMessage(ev: MessageEvent): void {
        const payload = decodePayload(ev.data);
        let error;
        if (payload) {
            const handled = this.onPayload(payload);
            if (!handled) error = [payload.seq, `Unrecognized message, not handled.`] as const;
        } else {
            error = [this.nextSeq(), `Invalid message, unknown format.`] as const;
        }
        if (error) {
            this.detach(
                this.sendPayload('error', {
                    seq: error[0],
                    error: serializeError(new SyntaxError(error[1])),
                }),
                'send error payload',
            );
        }
    }

    /**
     * Handles a decoded RPC payload.
     * @returns `true` if the payload type is handled here, `false` otherwise (`auth` and unknown types).
     */
    protected onPayload(payload: RpcPayload): boolean {
        switch (payload.type) {
            case 'call':
            case 'notify':
                this.detach(this.localCall(payload), 'local call');
                return true;
            case 'return': {
                const pending = this.pendingCalls.get(payload.seq);
                // The response counts as valid even when no request is waiting for it.
                if (!pending) return true;
                this.pendingCalls.delete(payload.seq);
                if (payload.error) {
                    pending[1](deserializeError(payload.error));
                } else {
                    pending[0](payload.result);
                }
                return true;
            }
            case 'subscribe':
                this.detach(this.localSubscribe(payload), 'local subscribe');
                return true;
            case 'unsubscribe': {
                const subscription = this.localSubscription.get(payload.seq);
                // The message counts as valid even when no such subscription exists.
                if (!subscription) return true;
                subscription.unsubscribe();
                this.localSubscription.delete(payload.seq);
                return true;
            }
            case 'publish': {
                const subscriber = this.pendingSubscriptions.get(payload.seq);
                // The message counts as valid even without a matching subscription;
                // ask the remote side to stop publishing it.
                if (!subscriber) {
                    this.detach(this.sendPayload('unsubscribe', { seq: payload.seq }), 'send unsubscribe');
                    return true;
                }
                if (payload.error) {
                    subscriber.error(deserializeError(payload.error));
                } else if (payload.complete) {
                    subscriber.complete();
                } else {
                    subscriber.next(payload.next);
                }
                return true;
            }
            case 'error':
                return true;
            case 'auth':
            default:
                return false;
        }
    }

    /**
     * Invokes a method on the local object for an incoming `call`/`notify` payload.
     * For `call`, sends a `return` payload carrying the result or the serialized error.
     * For `notify`, nothing is sent back.
     */
    protected async localCall(payload: RpcCallPayload | RpcNotifyPayload): Promise<void> {
        const noReturn = payload.type === 'notify';
        const method = this.local ? this.local[payload.method as Methods<RpcObject<TLocal>>] : undefined;
        const { seq } = payload;
        if (typeof method != 'function') {
            if (noReturn) return;
            return this.sendPayload('return', {
                seq,
                error: serializeError(new TypeError(`${payload.method} is not a function`)),
            });
        }
        try {
            const result: unknown = await Reflect.apply(method, this.local, payload.args);
            if (noReturn) return;
            return this.sendPayload('return', { seq, result });
        } catch (ex) {
            if (noReturn) return;
            return this.sendPayload('return', { seq, error: serializeError(ex) });
        }
    }

    /** Active subscriptions created for remote `subscribe` requests, keyed by request seq. */
    private readonly localSubscription = new Map<number, Subscription>();

    /**
     * Subscribes to a local observable method for an incoming `subscribe` payload.
     * Forwards every notification to the remote side as a `publish` payload.
     */
    protected async localSubscribe(payload: RpcSubscribePayload): Promise<void> {
        const method = this.local ? this.local[payload.method as Subjects<RpcObject<TLocal>>] : undefined;
        const { seq } = payload;
        if (typeof method != 'function') {
            return this.sendPayload('publish', {
                seq,
                error: serializeError(new TypeError(`${payload.method} is not a function`)),
            });
        }
        try {
            const result = from((await Reflect.apply(method, this.local, payload.args)) as ObservableLike<unknown>);
            const subscription = result.subscribe({
                next: (value) => {
                    this.detach(this.sendPayload('publish', { seq, next: value }), 'send publish');
                },
                error: (err) => {
                    this.localSubscription.delete(seq);
                    this.detach(this.sendPayload('publish', { seq, error: serializeError(err) }), 'send publish');
                },
                complete: () => {
                    this.localSubscription.delete(seq);
                    this.detach(this.sendPayload('publish', { seq, complete: true }), 'send publish');
                },
            });
            // A source that completed or errored synchronously is already closed; its handler above
            // has run, so storing it would only leak an entry that is never removed.
            if (!subscription.closed) this.localSubscription.set(seq, subscription);
        } catch (ex) {
            return this.sendPayload('publish', { seq, error: serializeError(ex) });
        }
    }

    /**
     * Sends a payload on the current socket once `ready` has settled successfully.
     * @throws Error if this RPC socket is destroyed; also rejects if `ready` rejects or the socket is unavailable.
     */
    protected async sendPayload<T extends RpcPayload['type']>(
        type: T,
        info: Omit<RpcPayload & { type: T }, 'type'>,
    ): Promise<void> {
        if (this.#destroyed) throw new Error(`RPC Socket destroyed.`);
        await this.ready;
        send(this.socket, type, info);
    }

    /**
     * Lets `task` run without awaiting it. A rejection is logged instead of becoming an
     * unhandled promise rejection.
     */
    private detach(task: Promise<unknown>, action: string): void {
        task.catch((ex: unknown) => {
            logger('[%s] %s failed. error=', this.id, action, ex);
        });
    }

    /** Sequence number to be used by the next outgoing request. */
    protected seq = 0;

    /**
     * Returns the current sequence number and advances it by 2.
     * NOTE(review): the stride of 2 presumably keeps the two peers' numbers disjoint
     * (e.g. one side starting at 1) — confirm against subclasses.
     */
    protected nextSeq(): number {
        const { seq } = this;
        this.seq += 2;
        return seq;
    }

    /** Outgoing calls awaiting a `return` payload, keyed by request seq. */
    private readonly pendingCalls = new Map<
        number,
        [resolve: (result: unknown) => void, reject: (error: Error) => void]
    >();

    /**
     * Calls a remote method.
     * @returns a promise for the remote result. It rejects with the remote error, with the send
     * failure if the request could not be sent, or when this RPC socket is destroyed.
     */
    // eslint-disable-next-line @typescript-eslint/promise-function-async
    call<TMethod extends Methods<TRemote>>(
        method: TMethod,
        ...args: RpcParameters<TRemote[TMethod]>
    ): Promise<RpcReturns<TRemote[TMethod]>> {
        return new Promise((resolve, reject) => {
            const seq = this.nextSeq();
            this.pendingCalls.set(seq, [resolve as (result: unknown) => void, reject]);
            void this.sendPayload('call', { seq, method, args }).catch((ex: unknown) => {
                // The request never went out, so no `return` will arrive for this seq.
                this.pendingCalls.delete(seq);
                reject(ex);
            });
        });
    }

    /** Calls a remote method without waiting for or receiving its result. Send failures are logged. */
    notify<TNotification extends Notifications<TRemote>>(
        method: TNotification,
        ...args: RpcParameters<TRemote[TNotification]>
    ): void {
        const seq = this.nextSeq();
        this.detach(this.sendPayload('notify', { seq, method, args }), 'send notify');
    }

    /** Subscribers of remote subscriptions awaiting `publish` payloads, keyed by request seq. */
    private readonly pendingSubscriptions = new Map<number, Subscriber<unknown>>();

    /**
     * Subscribes to a remote observable method. Each subscription to the returned observable
     * sends its own `subscribe` request. Unsubscribing sends `unsubscribe`, unless this RPC
     * socket is destroyed. If the request cannot be sent, the subscriber receives the send error.
     */
    subscribe<TSubject extends Subjects<TRemote>>(
        method: TSubject,
        ...args: RpcParameters<TRemote[TSubject]>
    ): Observable<RpcReturns<TRemote[TSubject]>> {
        return new Observable((subscriber) => {
            const seq = this.nextSeq();
            this.pendingSubscriptions.set(seq, subscriber);
            void this.sendPayload('subscribe', { seq, method, args }).catch((ex: unknown) => {
                subscriber.error(ex);
            });
            return () => {
                this.pendingSubscriptions.delete(seq);
                // Nothing can be sent any more once destroyed (this also runs during destroy()).
                if (!this.#destroyed) {
                    this.detach(this.sendPayload('unsubscribe', { seq }), 'send unsubscribe');
                }
            };
        });
    }

    #destroyed = false;

    /** Whether this RPC socket has been destroyed. */
    get destroyed(): boolean {
        return this.#destroyed;
    }

    /**
     * Destroys this RPC socket. Idempotent. Does the following:
     * - closes the WebSocket with code 1000 if it is open or connecting;
     * - unsubscribes all local subscriptions;
     * - rejects pending calls and errors pending remote subscriptions with "RPC Socket closed.".
     */
    destroy(): void {
        if (this.#destroyed) return;
        // Mark as destroyed first so that callbacks triggered below (subscription teardowns,
        // user error handlers) observe the destroyed state and do not try to send.
        this.#destroyed = true;
        logger('[%s] socket destroyed', this.id);
        const socket = this.__socket;
        if (socket && (socket.readyState === socket.CONNECTING || socket.readyState === socket.OPEN)) {
            socket.close(1000);
        }
        this.__socket = undefined;
        for (const s of this.localSubscription.values()) s.unsubscribe();
        this.localSubscription.clear();
        for (const [, reject] of this.pendingCalls.values()) reject(new Error(`RPC Socket closed.`));
        this.pendingCalls.clear();
        for (const s of this.pendingSubscriptions.values()) s.error(new Error(`RPC Socket closed.`));
        this.pendingSubscriptions.clear();
    }
}