@cloudpss/ubrpc
Version:
RPC server/client built on WebSocket and UBJSON.
346 lines (341 loc) • 13.5 kB
text/typescript
import { connected } from '@cloudpss/fetch';
import { from, Observable, type Subscriber, type Subscription } from 'rxjs';
import type {
ConnectionID,
RpcCallPayload,
RpcMetadata,
RpcNotifyPayload,
RpcPayload,
RpcSubscribePayload,
} from './types/payload.js';
import type {
Methods,
Notifications,
ObservableLike,
RpcObject,
RpcParameters,
RpcReturns,
Subjects,
} from './types/utils.js';
import { send } from './utils/messaging.js';
import { decodePayload, deserializeError, serializeError } from './utils/serialize.js';
import { logger } from './logger.js';
import { WebSocketAppCode } from './codes.js';
/** RPC connection: issues and serves calls, notifications and subscriptions over a WebSocket. */
export abstract class RpcSocket<TRemote extends object, TLocal extends object> {
    constructor(
        readonly id: ConnectionID,
        local?: RpcObject<TLocal>,
    ) {
        this.__local = local;
        this.makeReady();
    }
    protected _localMetadata?: RpcMetadata;
    protected _remoteMetadata?: RpcMetadata;
    /** Local authentication metadata */
    get localMetadata(): RpcMetadata | undefined {
        return this._localMetadata;
    }
    /** Remote authentication metadata; cleared when the socket errors, closes abnormally or fails to initialize */
    get remoteMetadata(): RpcMetadata | undefined {
        return this._remoteMetadata;
    }
    /** Whether the connection is authenticated, i.e. remote metadata is present */
    get authenticated(): boolean {
        return this._remoteMetadata != null;
    }
    private readonly __local?: RpcObject<TLocal>;
    /** Local object used to serve incoming calls and subscriptions */
    protected get local(): RpcObject<TLocal> | undefined {
        return this.__local;
    }
    private __socket?: WebSocket;
    /**
     * WebSocket used as the underlying transport.
     * @throws Error if this RPC socket is destroyed or no WebSocket has been assigned yet
     */
    get socket(): WebSocket {
        if (this.#destroyed) throw new Error(`RPC Socket destroyed.`);
        if (!this.__socket) throw new Error(`Socket not initialized`);
        return this.__socket;
    }
    /**
     * Replace the transport. A pending `ready` promise is (re)used and settled by `initSocket`,
     * which detaches/closes the previous WebSocket and authenticates the new one.
     * @throws Error if this RPC socket is destroyed
     */
    protected set socket(value: WebSocket) {
        if (this.#destroyed) throw new Error(`RPC Socket destroyed.`);
        if (this.__socket === value) return;
        const callbacks = this.makeReady();
        const oldValue = this.__socket;
        this.__socket = value;
        void this.initSocket(oldValue, value).then(...callbacks);
    }
    /**
     * Create a pending `ready` promise, or reuse the current one if its callbacks are still held.
     * @returns the `[resolve, reject]` callbacks of the pending `ready` promise
     */
    private makeReady(): Parameters<ConstructorParameters<typeof Promise<void>>[0]> {
        if (this.__readyCallbacks != null) return this.__readyCallbacks;
        const ready = new Promise<void>((...args) => (this.__readyCallbacks = args));
        ready
            .catch(() => {
                // Avoid unhandled rejection
            })
            .finally(() => {
                // Guard against clearing callbacks that belong to a newer `ready` promise
                if (this.ready === ready) this.__readyCallbacks = undefined;
            });
        this.ready = ready;
        if (!this.__readyCallbacks) throw new Error(`Bad promise`);
        return this.__readyCallbacks;
    }
    private __readyCallbacks?: Parameters<ConstructorParameters<typeof Promise<void>>[0]>;
    /** Awaited before every send; settled by `initSocket` for the current WebSocket */
    protected ready!: Promise<void>;
    private readonly __handlers = Object.freeze({
        open: (ev: Event) => this.onOpen(ev),
        close: (ev: CloseEvent) => this.onClose(ev),
        error: (ev: Event | ErrorEvent) => this.onError(ev),
        message: (ev: MessageEvent) => this.onMessage(ev),
    } as const);
    /**
     * Initialize a WebSocket: detach and close the previous one, wait for the new one to connect,
     * authenticate it, then (if it is still the current socket) attach event handlers and store the
     * remote metadata.
     * @throws rethrows any connection/authentication failure after clearing the remote metadata
     */
    protected async initSocket(oldValue: WebSocket | undefined, newValue: WebSocket): Promise<void> {
        try {
            if (oldValue) {
                oldValue.removeEventListener('open', this.__handlers.open);
                oldValue.removeEventListener('close', this.__handlers.close);
                oldValue.removeEventListener('error', this.__handlers.error);
                oldValue.removeEventListener('message', this.__handlers.message);
                if (oldValue.readyState === oldValue.OPEN || oldValue.readyState === oldValue.CONNECTING) {
                    logger('[%s] close old socket', this.id);
                    oldValue.close(WebSocketAppCode.REPLACED);
                }
            }
            await connected(newValue);
            const info = await this.authSocket();
            // The socket may have been replaced while connecting/authenticating
            if (this.__socket === newValue) {
                newValue.addEventListener('open', this.__handlers.open);
                newValue.addEventListener('close', this.__handlers.close);
                newValue.addEventListener('error', this.__handlers.error);
                newValue.addEventListener('message', this.__handlers.message);
                this._remoteMetadata = info ?? {};
            }
        } catch (ex) {
            logger('[%s] connection initialize failed. error=', this.id, ex);
            this._remoteMetadata = undefined;
            throw ex;
        }
    }
    /** Authenticate the current WebSocket; resolves with the remote metadata */
    protected abstract authSocket(): Promise<RpcMetadata>;
    /** Handle WebSocket `error`: drop authentication and make `ready` pending again */
    protected onError(_ev: ErrorEvent | Event): void {
        this._remoteMetadata = undefined;
        this.makeReady();
    }
    /** Handle WebSocket `open` (no-op by default) */
    protected onOpen(_ev: Event): void {
        //
    }
    /** Handle WebSocket `close`: a normal closure (code 1000) destroys this RPC socket, any other code makes `ready` pending again */
    protected onClose(ev: CloseEvent): void {
        this._remoteMetadata = undefined;
        if (ev.code === 1000) {
            this.destroy();
        } else {
            this.makeReady();
        }
    }
    /** Handle WebSocket `message`: decode and dispatch it, replying with an `error` payload when it cannot be handled */
    protected onMessage(ev: MessageEvent): void {
        const payload = decodePayload(ev.data);
        let error;
        if (payload) {
            const handled = this.onPayload(payload);
            if (!handled) error = [payload.seq, `Unrecognized message, not handled.`] as const;
        } else {
            error = [this.nextSeq(), `Invalid message, unknown format.`] as const;
        }
        if (error) {
            this.#trySendPayload('error', {
                seq: error[0],
                error: serializeError(new SyntaxError(error[1])),
            });
        }
    }
    /**
     * Dispatch a decoded RPC payload.
     * @returns `true` if the payload type is handled here, `false` otherwise (including `auth`)
     */
    protected onPayload(payload: RpcPayload): boolean {
        switch (payload.type) {
            case 'call':
            case 'notify':
                void this.localCall(payload);
                return true;
            case 'return': {
                const pending = this.pendingCalls.get(payload.seq);
                // The response is considered valid even if no matching call is pending
                if (!pending) return true;
                this.pendingCalls.delete(payload.seq);
                if (payload.error) {
                    pending[1](deserializeError(payload.error));
                } else {
                    pending[0](payload.result);
                }
                return true;
            }
            case 'subscribe':
                void this.localSubscribe(payload);
                return true;
            case 'unsubscribe': {
                const subscription = this.localSubscription.get(payload.seq);
                // The request is considered valid even if no matching subscription exists
                if (!subscription) return true;
                subscription.unsubscribe();
                this.localSubscription.delete(payload.seq);
                return true;
            }
            case 'publish': {
                const subscriber = this.pendingSubscriptions.get(payload.seq);
                // The message is considered valid even if no matching subscription exists;
                // ask the remote to stop publishing to it
                if (!subscriber) {
                    this.#trySendPayload('unsubscribe', { seq: payload.seq });
                    return true;
                }
                if (payload.error) {
                    // Remove first so the teardown does not echo an `unsubscribe` for a finished subscription
                    this.pendingSubscriptions.delete(payload.seq);
                    subscriber.error(deserializeError(payload.error));
                } else if (payload.complete) {
                    this.pendingSubscriptions.delete(payload.seq);
                    subscriber.complete();
                } else {
                    subscriber.next(payload.next);
                }
                return true;
            }
            case 'error':
                return true;
            case 'auth':
            default:
                return false;
        }
    }
    /**
     * Invoke a local method for an incoming `call`/`notify`. For `call`, the result or error is sent
     * back as a `return` payload; for `notify`, nothing is sent. Never rejects: send failures are logged.
     */
    protected async localCall(payload: RpcCallPayload | RpcNotifyPayload): Promise<void> {
        const noReturn = payload.type === 'notify';
        const method = this.local ? this.local[payload.method as Methods<RpcObject<TLocal>>] : undefined;
        const { seq } = payload;
        if (typeof method != 'function') {
            if (noReturn) return;
            this.#trySendPayload('return', {
                seq,
                error: serializeError(new TypeError(`${payload.method} is not a function`)),
            });
            return;
        }
        let result: unknown;
        try {
            result = await Reflect.apply(method, this.local, payload.args);
        } catch (ex) {
            if (!noReturn) this.#trySendPayload('return', { seq, error: serializeError(ex) });
            return;
        }
        if (!noReturn) this.#trySendPayload('return', { seq, result });
    }
    private readonly localSubscription = new Map<number, Subscription>();
    /**
     * Start a local subscription for an incoming `subscribe`, forwarding every notification as a
     * `publish` payload. Never rejects: send failures are logged.
     */
    protected async localSubscribe(payload: RpcSubscribePayload): Promise<void> {
        const method = this.local ? this.local[payload.method as Subjects<RpcObject<TLocal>>] : undefined;
        const { seq } = payload;
        if (typeof method != 'function') {
            this.#trySendPayload('publish', {
                seq,
                error: serializeError(new TypeError(`${payload.method} is not a function`)),
            });
            return;
        }
        try {
            const result = from((await Reflect.apply(method, this.local, payload.args)) as ObservableLike<unknown>);
            // Destroyed while awaiting the local method: nothing would ever unsubscribe, so do not subscribe
            if (this.#destroyed) return;
            const subscription = result.subscribe({
                next: (value) => {
                    this.#trySendPayload('publish', { seq, next: value });
                },
                error: (err) => {
                    this.localSubscription.delete(seq);
                    this.#trySendPayload('publish', { seq, error: serializeError(err) });
                },
                complete: () => {
                    this.localSubscription.delete(seq);
                    this.#trySendPayload('publish', { seq, complete: true });
                },
            });
            // A source that finished synchronously is already closed; registering it would leak the entry
            if (!subscription.closed) this.localSubscription.set(seq, subscription);
        } catch (ex) {
            this.#trySendPayload('publish', { seq, error: serializeError(ex) });
        }
    }
    /**
     * Send a payload once `ready` settles.
     * @throws (rejects) if this RPC socket is destroyed, `ready` rejects, or the socket is unavailable
     */
    protected async sendPayload<T extends RpcPayload['type']>(
        type: T,
        info: Omit<RpcPayload & { type: T }, 'type'>,
    ): Promise<void> {
        if (this.#destroyed) throw new Error(`RPC Socket destroyed.`);
        await this.ready;
        send(this.socket, type, info);
    }
    /** Fire-and-forget `sendPayload`: failures are logged instead of becoming unhandled rejections */
    #trySendPayload<T extends RpcPayload['type']>(type: T, info: Omit<RpcPayload & { type: T }, 'type'>): void {
        this.sendPayload(type, info).catch((ex: unknown) => {
            logger(`[%s] failed to send ${type} payload. error=`, this.id, ex);
        });
    }
    /** Sequence number for the next outgoing request */
    protected seq = 0;
    /**
     * Return the current sequence number and advance it by 2.
     * NOTE(review): the step of 2 presumably keeps the two peers' sequence spaces disjoint (even/odd) — confirm in subclasses.
     */
    protected nextSeq(): number {
        const { seq } = this;
        this.seq += 2;
        return seq;
    }
    private readonly pendingCalls = new Map<
        number,
        [resolve: (result: unknown) => void, reject: (error: Error) => void]
    >();
    /**
     * Call a remote method and wait for its `return` payload.
     * Rejects with the remote error, with the send failure if the request could not be sent,
     * or when this RPC socket is destroyed.
     */
    // eslint-disable-next-line @typescript-eslint/promise-function-async
    call<TMethod extends Methods<TRemote>>(
        method: TMethod,
        ...args: RpcParameters<TRemote[TMethod]>
    ): Promise<RpcReturns<TRemote[TMethod]>> {
        return new Promise((resolve, reject) => {
            const seq = this.nextSeq();
            this.pendingCalls.set(seq, [resolve as (result: unknown) => void, reject]);
            this.sendPayload('call', { seq, method, args }).catch((ex: unknown) => {
                // A request that was never sent will never get a response: fail the call now
                if (this.pendingCalls.delete(seq)) reject(ex);
            });
        });
    }
    /** Call a remote method without waiting for a result; send failures are logged */
    notify<TNotification extends Notifications<TRemote>>(
        method: TNotification,
        ...args: RpcParameters<TRemote[TNotification]>
    ): void {
        const seq = this.nextSeq();
        this.#trySendPayload('notify', { seq, method, args });
    }
    private readonly pendingSubscriptions = new Map<number, Subscriber<unknown>>();
    /**
     * Subscribe to a remote subject. Each subscription sends a `subscribe` payload; unsubscribing
     * sends `unsubscribe` unless the remote already finished the subscription or the socket was destroyed.
     * Errors with the send failure if the request could not be sent.
     */
    subscribe<TSubject extends Subjects<TRemote>>(
        method: TSubject,
        ...args: RpcParameters<TRemote[TSubject]>
    ): Observable<RpcReturns<TRemote[TSubject]>> {
        return new Observable((subscriber) => {
            const seq = this.nextSeq();
            this.pendingSubscriptions.set(seq, subscriber);
            this.sendPayload('subscribe', { seq, method, args }).catch((ex: unknown) => {
                if (this.pendingSubscriptions.delete(seq)) subscriber.error(ex);
            });
            return () => {
                // Only notify the remote while the subscription is still active on this side
                if (this.pendingSubscriptions.delete(seq)) this.#trySendPayload('unsubscribe', { seq });
            };
        });
    }
    #destroyed = false;
    /** Whether this RPC socket has been destroyed */
    get destroyed(): boolean {
        return this.#destroyed;
    }
    /**
     * Destroy this RPC socket: close the WebSocket (code 1000), release `ready` waiters, stop local
     * subscriptions, and fail all pending remote calls and subscriptions. Idempotent.
     */
    destroy(): void {
        if (this.#destroyed) return;
        // Mark first so re-entrant calls and callbacks triggered below observe the destroyed state
        this.#destroyed = true;
        logger('[%s] socket destroyed', this.id);
        const socket = this.__socket;
        this.__socket = undefined;
        if (socket && (socket.readyState === socket.CONNECTING || socket.readyState === socket.OPEN)) {
            socket.close(1000);
        }
        // Release anything still awaiting a pending `ready`
        this.__readyCallbacks?.[1](new Error(`RPC Socket destroyed.`));
        for (const s of this.localSubscription.values()) s.unsubscribe();
        this.localSubscription.clear();
        // Empty the maps before settling so teardowns/handlers do not act on these entries again
        const calls = [...this.pendingCalls.values()];
        this.pendingCalls.clear();
        for (const [, reject] of calls) reject(new Error(`RPC Socket closed.`));
        const subscribers = [...this.pendingSubscriptions.values()];
        this.pendingSubscriptions.clear();
        for (const s of subscribers) s.error(new Error(`RPC Socket closed.`));
    }
}