/*
 * Source: UNPKG — @cloudpss/ubrpc
 * Version: (not recorded)
 * 340 lines (331 loc), 13 kB
 */
import { connected } from '@cloudpss/fetch';
import { from, Observable, type Subscription } from 'rxjs';
import type {
    ConnectionID,
    RpcCallPayload,
    RpcMetadata,
    RpcNotifyPayload,
    RpcPayload,
    RpcSubscribePayload,
} from './types/payload.js';
import type {
    Methods,
    Notifications,
    ObservableLike,
    RpcObject,
    RpcParameters,
    RpcReturns,
    Subjects,
} from './types/utils.js';
import type { ObservableStatus, PromiseCallbacks } from './types/callbacks.js';
import { send } from './utils/messaging.js';
import { decodePayload, deserializeError, serializeError } from './utils/serialize.js';
import { logger } from './logger.js';
import { ReadyPromise } from './utils/ready.js';
import { shouldReconnectWebSocket, WebSocketAppCode } from './codes.js';

/**
 * RPC connection over a WebSocket.
 *
 * Subclasses provide the local object that answers incoming requests (`local`) and the
 * authentication step (`authSocket`). They attach a transport through the protected `socket` setter.
 */
export abstract class RpcSocket<TRemote extends object, TLocal extends object> {
    constructor(readonly id: ConnectionID) {}

    protected _localMetadata?: RpcMetadata;
    protected _remoteMetadata?: RpcMetadata;

    /** Local authentication metadata. */
    get localMetadata(): RpcMetadata | undefined {
        return this._localMetadata;
    }

    /**
     * Remote authentication metadata.
     * Set after a socket authenticates successfully; cleared on socket error, close or failed initialization.
     */
    get remoteMetadata(): RpcMetadata | undefined {
        return this._remoteMetadata;
    }

    /** Whether the connection is authenticated (remote metadata is present). */
    get authenticated(): boolean {
        return this._remoteMetadata != null;
    }

    /** Local object used to answer incoming calls, notifications and subscriptions. */
    protected abstract get local(): RpcObject<TLocal> | undefined;

    #socket?: WebSocket;

    /**
     * The WebSocket used as the underlying transport.
     * @throws Error if this RPC socket has been destroyed or no WebSocket has been assigned yet.
     */
    get socket(): WebSocket {
        if (this.#destroyed) throw new Error(`RPC Socket destroyed.`);
        if (!this.#socket) throw new Error(`Socket not initialized`);
        return this.#socket;
    }

    /**
     * Replaces the underlying WebSocket.
     * The previous socket is detached and closed by `initSocket`. `ready` is settled with the
     * outcome of initializing the new socket. Assigning the current socket again is a no-op.
     * @throws Error if this RPC socket has been destroyed.
     */
    protected set socket(value: WebSocket) {
        if (this.#destroyed) throw new Error(`RPC Socket destroyed.`);
        if (this.#socket === value) return;
        this.resetReady();
        const oldValue = this.#socket;
        this.#socket = value;
        this.ready.settle(this.initSocket(oldValue, value));
    }

    /** Readiness of the current socket; settled with the result of `initSocket`. */
    protected ready = new ReadyPromise();

    /** Replaces `ready` with a fresh instance, but only if the current one has already settled. */
    protected resetReady(): void {
        if (!this.ready.settled) return;
        this.ready = new ReadyPromise();
    }

    // Stable listener references, so the exact same functions can later be removed from a replaced socket.
    readonly #handlers = Object.freeze({
        open: (ev: Event) => this.onOpen(ev),
        close: (ev: CloseEvent) => this.onClose(ev),
        error: (ev: Event | ErrorEvent) => this.onError(ev),
        message: (ev: MessageEvent) => this.onMessage(ev),
    } as const);

    /**
     * Initializes a newly assigned WebSocket.
     *
     * Steps:
     * 1. Detach this object's listeners from `oldValue`, and close it with `WebSocketAppCode.REPLACED`
     *    if it is not already closed.
     * 2. Wait for `newValue` to connect, then authenticate.
     * 3. Attach listeners and store the remote metadata, but only if `newValue` is still the current socket.
     *
     * @throws Rethrows any connection or authentication failure after clearing the remote metadata.
     */
    protected async initSocket(oldValue: WebSocket | undefined, newValue: WebSocket): Promise<void> {
        try {
            const { open, close, error, message } = this.#handlers;
            if (oldValue) {
                oldValue.removeEventListener('open', open);
                oldValue.removeEventListener('close', close);
                oldValue.removeEventListener('error', error);
                oldValue.removeEventListener('message', message);
                if (oldValue.readyState !== oldValue.CLOSED) {
                    logger('[%s] close old socket', this.id);
                    oldValue.close(WebSocketAppCode.REPLACED);
                }
            }
            await connected(newValue);
            const info = await this.authSocket();
            // The socket may have been replaced while we were waiting; only wire up the one still in use.
            if (this.#socket === newValue) {
                newValue.addEventListener('open', open);
                newValue.addEventListener('close', close);
                newValue.addEventListener('error', error);
                newValue.addEventListener('message', message);
                this._remoteMetadata = info ?? {};
            }
        } catch (ex) {
            logger('[%s] connection initialize failed. error=%o', this.id, ex);
            this._remoteMetadata = undefined;
            throw ex;
        }
    }

    /** Authenticates the current WebSocket; resolves with the remote metadata. */
    protected abstract authSocket(): Promise<RpcMetadata>;

    /** Handles a WebSocket `error` event: logs it, clears the remote metadata and resets readiness. */
    protected onError(ev: ErrorEvent | Event): void {
        logger('[%s] socket error: %o', this.id, ev);
        this._remoteMetadata = undefined;
        this.resetReady();
    }

    /** Handles a WebSocket `open` event. Intentionally a no-op hook for subclasses. */
    protected onOpen(_ev: Event): void {
        //
    }

    /**
     * Handles a WebSocket `close` event.
     * Destroys this RPC socket if `shouldReconnectWebSocket(code)` is false; otherwise only resets readiness.
     */
    protected onClose(ev: CloseEvent): void {
        const { code } = ev;
        logger('[%s] socket closed, code=%d', this.id, code);
        this._remoteMetadata = undefined;
        if (!shouldReconnectWebSocket(code)) {
            this.destroy();
        } else {
            this.resetReady();
        }
    }

    /**
     * Handles a WebSocket `message` event.
     * Decodes the payload and dispatches it to `onPayload`. Undecodable or unhandled messages are
     * answered with an `error` payload wrapping a `SyntaxError`.
     */
    protected onMessage(ev: MessageEvent): void {
        const payload = decodePayload(ev.data);
        let error;
        if (payload) {
            const handled = this.onPayload(payload);
            if (!handled) error = [payload.seq, `Unrecognized message, not handled.`] as const;
        } else {
            error = [this.nextSeq(), `Invalid message, unknown format.`] as const;
        }
        if (error) {
            this.#background(
                this.sendPayload('error', {
                    seq: error[0],
                    error: serializeError(new SyntaxError(error[1])),
                }),
            );
        }
    }

    /**
     * Dispatches a decoded RPC payload.
     * @returns `true` if the payload type is handled here, `false` otherwise (including `auth`).
     */
    protected onPayload(payload: RpcPayload): boolean {
        switch (payload.type) {
            case 'call':
            case 'notify':
                this.#background(this.localCall(payload));
                return true;
            case 'return': {
                const pending = this.#pendingCalls.get(payload.seq);
                // A response with no matching pending request is still considered valid.
                if (!pending) return true;
                this.#pendingCalls.delete(payload.seq);
                if (payload.error) {
                    pending[1](deserializeError(payload.error));
                } else {
                    pending[0](payload.result);
                }
                return true;
            }
            case 'subscribe':
                this.#background(this.localSubscribe(payload));
                return true;
            case 'unsubscribe': {
                const subscription = this.#localSubscription.get(payload.seq);
                // An unsubscribe with no matching local subscription is still considered valid.
                if (!subscription) return true;
                subscription.unsubscribe();
                this.#localSubscription.delete(payload.seq);
                return true;
            }
            case 'publish': {
                const status = this.#pendingSubscriptions.get(payload.seq);
                // A publish with no matching subscription is still considered valid; tell the peer to stop.
                if (!status) {
                    this.#background(this.sendPayload('unsubscribe', { seq: payload.seq }));
                    return true;
                }
                if (payload.error) {
                    status[1] = true;
                    status[0].error(deserializeError(payload.error));
                } else if (payload.complete) {
                    status[1] = true;
                    status[0].complete();
                } else {
                    status[0].next(payload.next);
                }
                return true;
            }
            case 'error':
                return true;
            case 'auth':
            default:
                return false;
        }
    }

    /**
     * Invokes a method on the local object for an incoming `call` or `notify`.
     * For `call`, the result or error is sent back as a `return` payload. For `notify`, nothing is sent.
     */
    protected async localCall(payload: RpcCallPayload | RpcNotifyPayload): Promise<void> {
        const noReturn = payload.type === 'notify';
        const method = this.local ? this.local[payload.method as Methods<RpcObject<TLocal>>] : undefined;
        const { seq } = payload;
        if (typeof method !== 'function') {
            if (noReturn) return;
            return this.sendPayload('return', {
                seq,
                error: serializeError(new TypeError(`${payload.method} is not a function`)),
            });
        }
        try {
            const result: unknown = await Reflect.apply(method, this.local, payload.args);
            if (noReturn) return;
            return this.sendPayload('return', { seq, result });
        } catch (ex) {
            if (noReturn) return;
            return this.sendPayload('return', { seq, error: serializeError(ex) });
        }
    }

    /** Active subscriptions to local observables, keyed by the remote's subscription seq. */
    readonly #localSubscription = new Map<number, Subscription>();

    /**
     * Subscribes to a local observable method for an incoming `subscribe`.
     * Each emission, error or completion is forwarded as a `publish` payload.
     */
    protected async localSubscribe(payload: RpcSubscribePayload): Promise<void> {
        const method = this.local ? this.local[payload.method as Subjects<RpcObject<TLocal>>] : undefined;
        const { seq } = payload;
        if (typeof method !== 'function') {
            return this.sendPayload('publish', {
                seq,
                error: serializeError(new TypeError(`${payload.method} is not a function`)),
            });
        }
        try {
            const source = (await Reflect.apply(method, this.local, payload.args)) as ObservableLike<unknown>;
            // Destroyed while the local method was pending: nothing would ever unsubscribe, so don't subscribe.
            if (this.#destroyed) return;
            const subscription = from(source).subscribe({
                next: (value) => {
                    this.#background(this.sendPayload('publish', { seq, next: value }));
                },
                error: (err) => {
                    this.#localSubscription.delete(seq);
                    this.#background(this.sendPayload('publish', { seq, error: serializeError(err) }));
                },
                complete: () => {
                    this.#localSubscription.delete(seq);
                    this.#background(this.sendPayload('publish', { seq, complete: true }));
                },
            });
            // A synchronous source may already have completed or errored inside `subscribe`.
            // Storing its closed subscription would leak the map entry.
            if (!subscription.closed) this.#localSubscription.set(seq, subscription);
        } catch (ex) {
            return this.sendPayload('publish', { seq, error: serializeError(ex) });
        }
    }

    /**
     * Sends a payload once the current socket is ready.
     * @throws (as a rejection) Error if this RPC socket is destroyed, or whatever `ready` / `socket` reject with.
     */
    protected async sendPayload<T extends RpcPayload['type']>(
        type: T,
        info: Omit<RpcPayload & { type: T }, 'type'>,
    ): Promise<void> {
        if (this.#destroyed) throw new Error(`RPC Socket destroyed.`);
        await this.ready.value;
        send(this.socket, type, info);
    }

    /**
     * Runs `task` without awaiting it, logging a rejection instead of leaving it unhandled.
     * An unhandled rejection terminates the process under Node.js' default settings.
     */
    #background(task: Promise<unknown>): void {
        task.catch((ex: unknown) => {
            logger('[%s] background task failed. error=%o', this.id, ex);
        });
    }

    /** Next sequence number to hand out. */
    protected seq = 0;

    /**
     * Returns the current sequence number and advances it by 2.
     * NOTE(review): the step of 2 presumably lets each peer use a distinct parity — confirm with subclasses.
     */
    protected nextSeq(): number {
        const { seq } = this;
        this.seq += 2;
        return seq;
    }

    /** Outstanding remote calls awaiting a `return`, keyed by seq. */
    readonly #pendingCalls = new Map<number, PromiseCallbacks>();

    /**
     * Calls a remote method.
     * @returns A promise that settles with the remote `return` payload. It rejects if the request
     * cannot be sent, or if this RPC socket is destroyed first.
     */
    // eslint-disable-next-line @typescript-eslint/promise-function-async
    call<TMethod extends Methods<TRemote>>(
        method: TMethod,
        ...args: RpcParameters<TRemote[TMethod]>
    ): Promise<RpcReturns<TRemote[TMethod]>> {
        return new Promise((resolve, reject) => {
            const seq = this.nextSeq();
            this.#pendingCalls.set(seq, [resolve as (result: unknown) => void, reject]);
            // If the request never leaves, no `return` will ever arrive: fail the caller instead of hanging.
            this.sendPayload('call', { seq, method, args }).catch((ex: unknown) => {
                this.#pendingCalls.delete(seq);
                reject(ex);
            });
        });
    }

    /** Calls a remote method and discards its result. Send failures are logged, not thrown. */
    notify<TNotification extends Notifications<TRemote>>(
        method: TNotification,
        ...args: RpcParameters<TRemote[TNotification]>
    ): void {
        const seq = this.nextSeq();
        this.#background(this.sendPayload('notify', { seq, method, args }));
    }

    /** Remote subscriptions awaiting `publish` payloads, keyed by seq. */
    readonly #pendingSubscriptions = new Map<number, ObservableStatus>();

    /**
     * Subscribes to a remote observable method.
     * Each subscription to the returned Observable sends its own `subscribe` request. Unsubscribing
     * before the remote side finished sends an `unsubscribe`. The Observable errors if the request
     * cannot be sent.
     */
    subscribe<TSubject extends Subjects<TRemote>>(
        method: TSubject,
        ...args: RpcParameters<TRemote[TSubject]>
    ): Observable<RpcReturns<TRemote[TSubject]>> {
        return new Observable((subscriber) => {
            const seq = this.nextSeq();
            // status[1] marks that the remote side has finished (error/complete) this subscription.
            const status: ObservableStatus<RpcReturns<TRemote[TSubject]>> = [subscriber, false];
            this.#pendingSubscriptions.set(seq, status);
            this.sendPayload('subscribe', { seq, method, args }).catch((ex: unknown) => {
                if (status[1] || subscriber.closed) return;
                this.#pendingSubscriptions.delete(seq);
                status[1] = true;
                subscriber.error(ex);
            });
            return () => {
                this.#pendingSubscriptions.delete(seq);
                if (!status[1]) this.#background(this.sendPayload('unsubscribe', { seq }));
            };
        });
    }

    #destroyed = false;

    /** Whether this RPC socket has been destroyed. */
    get destroyed(): boolean {
        return this.#destroyed;
    }

    /**
     * Destroys this RPC socket. Idempotent.
     *
     * - Closes the WebSocket with code 1000.
     * - Unsubscribes local subscriptions.
     * - Rejects pending calls and errors pending remote subscriptions with `RPC Socket closed.`.
     */
    destroy(): void {
        if (this.#destroyed) return;
        // Mark destroyed first. Subscriber error handlers and teardowns run synchronously below;
        // any call or send they start must fail fast rather than register work that is never settled.
        this.#destroyed = true;
        logger('[%s] socket destroyed', this.id);
        if (this.#socket) {
            this.#socket.close(1000);
            this.#socket = undefined;
        }
        for (const s of this.#localSubscription.values()) {
            s.unsubscribe();
        }
        this.#localSubscription.clear();
        for (const [, reject] of this.#pendingCalls.values()) {
            reject(new Error(`RPC Socket closed.`));
        }
        this.#pendingCalls.clear();
        for (const s of this.#pendingSubscriptions.values()) {
            s[1] = true;
            s[0].error(new Error(`RPC Socket closed.`));
        }
        this.#pendingSubscriptions.clear();
    }
}