stt-sdk
Version:
基于 LLMs 的语音转文本 SDK
287 lines (239 loc) • 6.67 kB
text/typescript
import { BaseSTTClientUtil } from "./BaseSTTClientUtil.js";
/**
* Hook types that can be registered on an STT client.
* Every member's string value is identical to its name.
*/
export enum BaseSTTClientHookType {
// Invoked with transcribed text (see BaseSTTClient.textReceiveHandler / voiceStopHandler).
TEXT_RECEIVE = "TEXT_RECEIVE",
// Invoked by BaseSTTClient.connectErrorHandler with the value it receives.
CONNECT_ERROR_HANDLE = "CONNECT_ERROR_HANDLE",
// Invoked by BaseSTTClient.connectCloseHandler with the value it receives.
CONNECT_CLOSE_HANDLE = "CONNECT_CLOSE_HANDLE",
// Invoked by BaseSTTClient.connectOpenHandler with the caught error when the voice stream cannot be set up.
DEVICE_OPEN_FAIL_HANDLE = "DEVICE_OPEN_FAIL_HANDLE"
}
/**
* Signature shared by all hook callbacks: accepts any arguments and returns a Promise.
*/
export type HookFunctionType = (...arg: any[]) => Promise<any>;
/**
* Name of the `window` event that BaseSTTClient.connectOpenHandler listens for
* in order to stop the active voice stream.
*/
export const STT_VOICE_STOP = "voice-stop";
/**
* STT client interface: the lifecycle callbacks and operations every
* STT client exposes.
*/
export interface BaseSTTClientInterface {
/**
* Connection initialization handling.
*/
onConnectInit: () => void
/**
* Connection error handling.
*/
onConnectError: () => void
/**
* Connection close handling.
*/
onConnectClose: () => void
/**
* Connection open handling.
*/
onConnectOpen: () => void
/**
* Text receive handling.
*/
onTextReceive: () => void
/**
* Registers a hook callback for the given hook type.
*/
addHook: <R>(type: BaseSTTClientHookType, func: (...arg: any[]) => Promise<R>) => void
/**
* Closes the connection.
*/
closeConnection: () => void
}
/**
 * Abstract base class for STT (speech-to-text) clients.
 *
 * It stores the hook registry, the access token, the client configuration and
 * the transcription mode, and offers shared handlers for the "connection
 * opened", "text received", "connection closed" and "connection error" events.
 *
 * The `on*`, `addHook`, `closeConnection` and `onEndOfStream` methods are
 * placeholders that throw until a subclass overrides them.
 *
 * @typeParam T - shape of the client configuration object.
 */
export abstract class BaseSTTClient<T> extends BaseSTTClientUtil implements BaseSTTClientInterface {
  /**
   * Hooks that must be registered (checked by `checkRequiredHookIsAllRegister`).
   */
  private static readonly requiredHookNames = [
    BaseSTTClientHookType.TEXT_RECEIVE
  ];

  /**
   * Registered hooks, keyed by hook type.
   */
  private readonly HooksMap: { [key: string]: HookFunctionType } = {};

  /**
   * Temporary access token.
   */
  private access_token: string | null = null;

  /**
   * Client configuration.
   */
  private client_config: T | null = null;

  /**
   * Transcription mode: in "rt" every received text is forwarded to the
   * TEXT_RECEIVE hook immediately; in "all" texts are buffered and delivered
   * as one string when the voice stream is stopped.
   */
  private mode: "rt" | "all" = "rt";

  /**
   * Buffer of received texts, used only in "all" mode.
   */
  private textCollectionList: string[] = [];

  /**
   * The currently registered `STT_VOICE_STOP` listener, if any.
   */
  private latestListener: (() => void) | null = null;

  /**
   * Verifies that every required hook has been registered.
   *
   * @throws Error naming the first required hook that is missing.
   */
  public checkRequiredHookIsAllRegister(): void {
    BaseSTTClient.requiredHookNames.forEach(key => {
      if (!this.getHook(key)) throw new Error(`${key} => Hook未注册`)
    });
  }

  /**
   * Connection-open handler: obtains a voice stream via `defaultVoiceStream()`,
   * registers a one-shot `STT_VOICE_STOP` listener on `window` that stops it,
   * and hands the stream to `defaultSlice()`.
   *
   * The work runs asynchronously; this method returns immediately. If the
   * stream cannot be set up, the DEVICE_OPEN_FAIL_HANDLE hook (when
   * registered) is called with the error and slicing is not started.
   *
   * @param stream_slice_handler - callback forwarded to `defaultSlice()`;
   *   presumably receives each recorded audio slice — confirm in BaseSTTClientUtil.
   */
  public connectOpenHandler(stream_slice_handler: (data: Blob) => void) {
    const codec = this.pickSupportedMime();
    void (async () => {
      let stream: MediaStream | null = null;
      try {
        stream = await this.defaultVoiceStream();
        // Only one stop listener may be active at a time.
        if (this.latestListener) {
          window.removeEventListener(STT_VOICE_STOP, this.latestListener);
        }
        const listener: () => void = () => {
          // One-shot: unregister first so a repeated stop event cannot
          // re-run the stop logic (and, in "all" mode, re-deliver text),
          // and so the closure no longer pins the stopped stream.
          window.removeEventListener(STT_VOICE_STOP, listener);
          if (this.latestListener === listener) {
            this.latestListener = null;
          }
          this.voiceStopHandler(stream!);
        };
        this.latestListener = listener;
        window.addEventListener(STT_VOICE_STOP, listener);
      } catch (e) {
        const handler = this.getHook(BaseSTTClientHookType.DEVICE_OPEN_FAIL_HANDLE);
        handler && handler(e);
        return;
      }
      // NOTE(review): sliceTimeout is presumably in milliseconds — confirm in BaseSTTClientUtil.
      this.defaultSlice(stream, stream_slice_handler, { mime: codec, sliceTimeout: 250 });
    })();
  }

  /**
   * Stops every track of the given stream. In "all" mode the buffered text is
   * then delivered to the TEXT_RECEIVE hook as a single string and the buffer
   * is reset, so a later session does not re-emit text from this one.
   *
   * The TEXT_RECEIVE hook must be registered (see
   * `checkRequiredHookIsAllRegister`), otherwise the call fails.
   */
  private voiceStopHandler(stream: MediaStream) {
    stream.getTracks().forEach(track => track.stop());
    if (this.mode === "all") {
      const fullText = this.textCollectionList.join("");
      // Reset before invoking the hook so the buffer is clean even if the hook throws.
      this.textCollectionList = [];
      this.getHook(BaseSTTClientHookType.TEXT_RECEIVE)(fullText);
    }
  }

  /**
   * Text-receive handler: buffers the text in "all" mode, otherwise forwards
   * it to the TEXT_RECEIVE hook right away.
   *
   * @param text - the received text fragment.
   */
  public textReceiveHandler(text: string): void {
    if (this.mode === "all") {
      this.textCollectionList.push(text);
      return;
    }
    this.getHook(BaseSTTClientHookType.TEXT_RECEIVE)(text);
  }

  /**
   * Connection-close handler: forwards `e` to the CONNECT_CLOSE_HANDLE hook
   * when one is registered.
   */
  public connectCloseHandler(e: any): void {
    const handler = this.getHook(BaseSTTClientHookType.CONNECT_CLOSE_HANDLE);
    handler && handler(e);
  }

  /**
   * Connection-error handler: forwards `e` to the CONNECT_ERROR_HANDLE hook
   * when one is registered.
   */
  public connectErrorHandler(e: any): void {
    const handler = this.getHook(BaseSTTClientHookType.CONNECT_ERROR_HANDLE);
    handler && handler(e);
  }

  /**
   * Hook getter.
   *
   * NOTE: the declared return type is non-nullable, but at runtime the result
   * is `undefined` when no hook was registered for `key`; check before
   * invoking an optional hook.
   */
  public getHook(key: BaseSTTClientHookType): HookFunctionType {
    return this.HooksMap[key]!;
  }

  /**
   * Hook setter: stores `func` under `key`, replacing any previous hook.
   */
  public setHook(key: BaseSTTClientHookType, func: HookFunctionType): void {
    this.HooksMap[key] = func;
  }

  /**
   * access_token getter.
   */
  public getAccessToken(): string | null {
    return this.access_token;
  }

  /**
   * access_token setter.
   */
  public setAccessToken(token: string) {
    this.access_token = token;
  }

  /**
   * client_config getter.
   */
  public getClientConfig(): T | null {
    return this.client_config;
  }

  /**
   * client_config setter.
   */
  public setClientConfig(config: T) {
    this.client_config = config;
  }

  /**
   * mode setter.
   */
  public setMode(mode: "rt" | "all"): void {
    this.mode = mode;
  }

  /**
   * mode getter.
   */
  public getMode(): "rt" | "all" {
    return this.mode;
  }

  /**
   * Connection initialization handling. Placeholder: always throws until overridden.
   */
  public onConnectInit(): void {
    throw new Error("onConnectInit 未实现");
  }

  /**
   * Connection error handling. Placeholder: always throws until overridden.
   */
  public onConnectError(): void {
    throw new Error("onConnectError 未实现");
  }

  /**
   * Connection close handling. Placeholder: always throws until overridden.
   */
  public onConnectClose(): void {
    throw new Error("onConnectClose 未实现");
  }

  /**
   * Connection open handling. Placeholder: always throws until overridden.
   */
  public onConnectOpen(): void {
    throw new Error("onConnectOpen 未实现");
  }

  /**
   * Text receive handling. Placeholder: always throws until overridden.
   */
  public onTextReceive(): void {
    throw new Error("onTextReceive 未实现");
  }

  /**
   * Hook registration. Placeholder: always throws until overridden.
   */
  public addHook<R>(type: BaseSTTClientHookType, func: (...arg: any[]) => Promise<R>): void {
    throw new Error("addHook 未实现");
  }

  /**
   * Closes the connection. Placeholder: always throws until overridden.
   */
  public closeConnection() {
    throw new Error("closeConnection 未实现");
  }

  /**
   * End-of-stream handling for streaming transfer. Placeholder: always throws until overridden.
   */
  public onEndOfStream() {
    throw new Error("onEndOfStream 未实现");
  }
}