@cloudpss/ubrpc
Version:
RPC server/client built on WebSocket and UBJSON.
163 lines (153 loc) • 6.24 kB
text/typescript
import { waitAuth } from './auth.js';
import { WebSocketAppCode } from './codes.js';
import { logger } from './logger.js';
import { RpcSocket } from './socket.js';
import type { ConnectionID, RpcMetadata } from './types/payload.js';
import type { RpcObject } from './types/utils.js';
import { send } from './utils/messaging.js';
import { serializeError } from './utils/serialize.js';
import { VERSION } from './version.js';
/** Key of the RpcServer hook that an RpcServerSocket calls when its WebSocket closes */
const kOnClose: unique symbol = Symbol('kOnClose');
/** Key under which the authenticated client metadata is stored on the raw WebSocket */
const kMetadata: unique symbol = Symbol('kMetadata');
/** Key of the RpcServerSocket method that swaps in a newly authenticated WebSocket */
const kReplaceSocket: unique symbol = Symbol('kReplaceSocket');

/** A WebSocket that has passed authentication and carries the client's metadata */
type WebSocketWithMetadata = WebSocket & { [kMetadata]: RpcMetadata };

/**
 * Client authentication callback.
 * @param metadata - metadata sent by the client for authentication
 * @returns metadata the server sends back to the client
 * @throws when authentication fails
 */
interface RpcClientAuthenticator {
    (metadata: RpcMetadata): RpcMetadata | Promise<RpcMetadata>;
}

/** Counter used to label incoming connections in logs before their real connection id is known */
let tempIdGen = 0;
/** Group of RPC connections established by a WebSocket server */
export class RpcServer<TRemote extends object, TLocal extends object> {
    constructor(
        /** Server-side RPC implementation, or a factory that builds one for a given connection */
        local?: RpcObject<TLocal> | ((socket: RpcServerSocket<TRemote, TLocal>) => RpcObject<TLocal>),
        /** Client authentication method; the default accepts every client and replies with empty metadata */
        readonly authenticator: RpcClientAuthenticator = () => ({}),
    ) {
        if (local == null) {
            this.local = undefined;
        } else if (typeof local === 'function') {
            this.local = local as (socket: RpcServerSocket<TRemote, TLocal>) => RpcObject<TLocal>;
        } else {
            // Wrap a plain object so `local` always has the factory shape.
            this.local = () => local;
        }
    }
    /** Factory for the local object that answers calls on a given connection */
    readonly local?: (socket: RpcServerSocket<TRemote, TLocal>) => RpcObject<TLocal>;
    /** Connected clients, keyed by connection id */
    protected readonly _sockets = new Map<ConnectionID, RpcServerSocket<TRemote, TLocal>>();
    /** Connected clients, keyed by connection id (read-only view) */
    get sockets(): ReadonlyMap<ConnectionID, RpcServerSocket<TRemote, TLocal>> {
        return this._sockets;
    }
    /**
     * Call this after a WebSocket connects to establish (or resume) an RPC connection.
     *
     * The socket is authenticated first. If a live connection with the same id exists, the socket
     * is attached to it (resume); otherwise a new {@link RpcServerSocket} is created.
     * @throws when authentication fails (the WebSocket is closed in that case)
     */
    async connect(socket: WebSocket): Promise<RpcServerSocket<TRemote, TLocal>> {
        const tempId = `#${++tempIdGen}`;
        logger('[%s] incoming connection', tempId);
        socket.binaryType = 'arraybuffer';
        const s = socket as WebSocketWithMetadata;
        const [id, metadata] = await this.authSocket(socket, tempId);
        s[kMetadata] = metadata;
        // A new socket for this id cancels any pending drop of the previous connection, whether we
        // resume it or replace it below; otherwise the stale timer could drop the new connection.
        const tid = this.disconnectingSockets.get(id);
        if (tid != null) {
            this.disconnectingSockets.delete(id);
            clearTimeout(tid);
        }
        let client = this._sockets.get(id);
        // A destroyed connection (e.g. destroy() called during the grace period) must not be resumed.
        if (!client || client.destroyed) {
            logger('[%s] new connection created', id);
            client = new RpcServerSocket(id, this);
            this._sockets.set(id, client);
        } else {
            logger('[%s] connection resumed', id);
        }
        await client[kReplaceSocket](s);
        return client;
    }
    /**
     * Authenticate a client.
     *
     * This waits for the client's auth message, runs {@link authenticator} on its metadata, and replies
     * with the server metadata. On failure it replies with the serialized error (as an `auth` reply when
     * the request's seq/id are known, otherwise as an `error` message), closes the WebSocket with
     * `AUTH_ERROR`, and rethrows.
     * @returns the connection id and the metadata the client sent
     */
    protected async authSocket(socket: WebSocket, tempId: string): Promise<[ConnectionID, RpcMetadata]> {
        let seq, id, remoteMetadata;
        try {
            logger('[%s] authenticating...', tempId);
            [seq, id, remoteMetadata] = await waitAuth(socket);
            logger('[%s -> %s] got server auth. remoteMeta=%o', tempId, id, remoteMetadata);
            const localMetadata = await this.authenticator(remoteMetadata);
            logger('[%s] server auth success. localMeta=%o', id, localMetadata);
            send(socket, 'auth', {
                seq,
                id,
                version: VERSION,
                metadata: localMetadata,
            });
            return [id, remoteMetadata];
        } catch (ex) {
            logger('[%s] server auth failed. remoteMeta=%o, error=%o', id ?? tempId, remoteMetadata, ex);
            if (seq != null && id != null) {
                send(socket, 'auth', {
                    seq,
                    id,
                    version: VERSION,
                    metadata: {},
                    error: serializeError(ex),
                });
            } else {
                // The auth request could not be read, so there is no seq/id to answer; use a generic error.
                send(socket, 'error', {
                    seq: 1,
                    error: serializeError(ex),
                });
            }
            socket.close(WebSocketAppCode.AUTH_ERROR);
            throw ex;
        }
    }
    /** Pending drop timers for disconnected connections, keyed by connection id */
    private readonly disconnectingSockets = new Map<ConnectionID, ReturnType<typeof setTimeout>>();
    /**
     * Called when a connection's WebSocket closes.
     *
     * This schedules the connection to be dropped after a grace period (immediately if it is already
     * destroyed), so a client reconnecting with the same id in the meantime resumes it.
     */
    [kOnClose](socket: RpcServerSocket<TRemote, TLocal>): void {
        if (this._sockets.get(socket.id) !== socket) {
            // This instance was already dropped or superseded by a newer connection with the same id.
            // Timers are keyed by id, so scheduling here would cancel or remove the newer connection;
            // just make sure the stale instance is torn down.
            socket.destroy();
            return;
        }
        const delay = socket.destroyed ? 0 : 5000;
        logger('[%s] socket closed, drop connection in %dms', socket.id, delay);
        const oldTid = this.disconnectingSockets.get(socket.id);
        if (oldTid != null) {
            clearTimeout(oldTid);
        }
        const tid = setTimeout(() => {
            logger('[%s] connection dropped', socket.id);
            socket.destroy();
            // Only remove the map entry if it still refers to this instance.
            if (this._sockets.get(socket.id) === socket) {
                this._sockets.delete(socket.id);
            }
            this.disconnectingSockets.delete(socket.id);
        }, delay);
        this.disconnectingSockets.set(socket.id, tid);
    }
}
/** A single RPC connection accepted by an {@link RpcServer} */
export class RpcServerSocket<TRemote extends object, TLocal extends object> extends RpcSocket<TRemote, TLocal> {
    constructor(
        id: ConnectionID,
        /** Server that owns this connection */
        readonly server: RpcServer<TRemote, TLocal>,
    ) {
        super(id);
        // NOTE(review): `seq` is inherited from RpcSocket; presumably the server side starts its sequence at 1 — confirm in socket.ts
        this.seq = 1;
    }

    /**
     * Resolves with the metadata that the server stored on the current WebSocket.
     * The authentication handshake itself is done by RpcServer before the socket is attached here.
     */
    // eslint-disable-next-line @typescript-eslint/promise-function-async
    protected authSocket(): Promise<RpcMetadata> {
        const authenticated = this.socket as WebSocketWithMetadata;
        const metadata = authenticated[kMetadata];
        return Promise.resolve(metadata);
    }

    /** Local object answering calls on this connection, built by the server's factory (if any) */
    protected override get local(): RpcObject<TLocal> | undefined {
        const owner = this.server;
        const factory = owner.local;
        // Invoke with the server as `this`, matching a member call on `server.local`.
        return factory == null ? undefined : factory.call(owner, this);
    }

    /** Runs the base close handling, then notifies the owning server so it can schedule dropping this connection */
    protected override onClose(ev: CloseEvent): void {
        super.onClose(ev);
        const owner = this.server;
        owner[kOnClose](this);
    }

    /** Attaches a newly authenticated WebSocket and returns the base class's `ready` promise */
    async [kReplaceSocket](newSocket: WebSocketWithMetadata): Promise<void> {
        this.socket = newSocket;
        return this.ready;
    }
}