UNPKG

web-voice-detection

Version:

A WebAssembly-powered Voice Activity Detection library for the browser.

71 lines (58 loc) 1.88 kB
import type { InferenceSession, Tensor } from "onnxruntime-web";
import { log } from "./logging";

/**
 * Type of the `onnxruntime-web` module namespace. The runtime is passed in by
 * callers rather than imported as a value here, which is why the import above
 * is type-only.
 */
export type ONNXRuntimeAPI = typeof import("onnxruntime-web");

/** Asynchronously supplies the raw bytes of the ONNX model file. */
export type ModelFetcher = () => Promise<ArrayBuffer>;

/** Options related to the ONNX runtime. */
export type OrtOptions = {
  /**
   * Hook that receives the runtime module, presumably so callers can adjust
   * its settings before use — confirm with the caller. Not invoked in this file.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- public signature kept for backward compatibility
  ortConfig?: (ort: ONNXRuntimeAPI) => any;
};

/**
 * Per-frame detection result. As produced by {@link ONNXModel.process},
 * `notSpeech` is always `1 - isSpeech`.
 */
export type SpeechProbabilities = {
  notSpeech: number;
  isSpeech: number;
};

/** Stateful per-frame detector contract implemented by {@link ONNXModel}. */
export type Model = {
  /** Resets the state carried between frames back to zeros. */
  reset_state: () => void;
  /** Scores one audio frame and updates the carried state. */
  process: (arr: Float32Array) => Promise<SpeechProbabilities>;
};

/** Sample rate passed to the model's `sr` input on every run. */
const SAMPLE_RATE = 16000n;

/** Dimensions of each carried state tensor (`h` and `c`). */
const STATE_DIMS = [2, 1, 64] as const;

/** Number of elements in one state tensor (product of {@link STATE_DIMS}). */
const STATE_SIZE = STATE_DIMS[0] * STATE_DIMS[1] * STATE_DIMS[2];

/**
 * Voice-activity detector backed by an ONNX model executed through an
 * injected `onnxruntime-web` instance.
 *
 * The model is stateful: each {@link ONNXModel.process} call feeds the
 * previous call's `hn`/`cn` outputs back in as the `h`/`c` inputs. Frames
 * must therefore be processed in order and one at a time (overlapping calls
 * would race on `_h`/`_c`). Call {@link ONNXModel.reset_state} between
 * unrelated audio streams.
 *
 * NOTE(review): the tensor names (`input`, `h`, `c`, `sr` → `output`, `hn`,
 * `cn`) and the `[2, 1, 64]` state shape look like a Silero-VAD v3/v4-style
 * model — confirm against the model the fetcher actually supplies.
 */
export class ONNXModel implements Model {
  _session!: InferenceSession;
  _h!: Tensor;
  _c!: Tensor;
  _sr!: Tensor;

  /**
   * @param ort The `onnxruntime-web` module used to create sessions and tensors.
   * @param modelFetcher Supplies the model bytes when {@link ONNXModel.init} runs.
   */
  constructor(
    private readonly ort: ONNXRuntimeAPI,
    private readonly modelFetcher: ModelFetcher,
  ) {}

  /** Constructs a model and awaits {@link ONNXModel.init} before returning it. */
  static new = async (
    ort: ONNXRuntimeAPI,
    modelFetcher: ModelFetcher,
  ): Promise<ONNXModel> => {
    const model = new ONNXModel(ort, modelFetcher);
    await model.init();
    return model;
  };

  /**
   * Fetches the model bytes, creates the inference session, builds the
   * constant sample-rate tensor and zeroes the carried state.
   */
  init = async (): Promise<void> => {
    log.debug("initializing detection");
    const modelArrayBuffer = await this.modelFetcher();
    this._session = await this.ort.InferenceSession.create(modelArrayBuffer);
    // The sample rate is sent as a one-element int64 tensor with every frame.
    this._sr = new this.ort.Tensor("int64", [SAMPLE_RATE]);
    this.reset_state();
    log.debug("detection is initialized");
  };

  /** Replaces `_h` and `_c` with fresh all-zero state tensors. */
  reset_state = (): void => {
    // Separate buffers so the two state tensors never share storage.
    this._h = new this.ort.Tensor("float32", new Float32Array(STATE_SIZE), STATE_DIMS);
    this._c = new this.ort.Tensor("float32", new Float32Array(STATE_SIZE), STATE_DIMS);
  };

  /**
   * Runs the model on one audio frame and carries its state forward.
   *
   * @param audioFrame Samples for a single frame, fed to the model as a
   *   `[1, audioFrame.length]` float32 tensor. Required frame lengths are
   *   dictated by the model, not checked here.
   * @returns The model's speech probability for the frame and its complement.
   * @throws Error if called before {@link ONNXModel.init}, if the model does
   *   not return the `output`, `hn` and `cn` outputs, or if the speech
   *   probability is not numeric. State is left untouched when this throws
   *   after the run.
   */
  process = async (audioFrame: Float32Array): Promise<SpeechProbabilities> => {
    if (this._session === undefined) {
      throw new Error("ONNXModel.process called before init()");
    }
    const input = new this.ort.Tensor("float32", audioFrame, [1, audioFrame.length]);
    const out = await this._session.run({
      input,
      h: this._h,
      c: this._c,
      sr: this._sr,
    });

    // Validate before touching state so a mismatched model fails loudly here
    // instead of corrupting `_h`/`_c` and breaking the next call.
    const { output, hn, cn } = out;
    if (output === undefined || hn === undefined || cn === undefined) {
      throw new Error(
        "detection model did not return the expected outputs: output, hn, cn",
      );
    }
    const isSpeech = output.data[0];
    if (typeof isSpeech !== "number") {
      throw new Error("detection model returned a non-numeric speech probability");
    }

    // Feed the updated state into the next call.
    this._h = hn;
    this._c = cn;
    return { notSpeech: 1 - isSpeech, isSpeech };
  };
}