UNPKG

echogarden

Version:

An easy-to-use speech toolset. Includes tools for synthesis, recognition, alignment, speech translation, language detection, source separation and more.

183 lines (130 loc) 5.53 kB
import { decibelsToGainFactor, RawAudio } from '../audio/AudioUtilities.js'
import type * as Onnx from 'onnxruntime-node'
import { getOnnxSessionOptions, OnnxExecutionProvider } from '../utilities/OnnxUtilities.js'
import { readdir } from '../utilities/FileSystem.js'
import { joinPath } from '../utilities/PathUtilities.js'
import { stftr, stiftr } from '../dsp/FFT.js'
import { clip, concatFloat32Arrays } from '../utilities/Utilities.js'
import { Logger } from '../utilities/Logger.js'

/**
 * Denoises audio with an NSNet2 ONNX model.
 *
 * Any option left undefined in `options` falls back to the corresponding value in
 * `defaultNSNet2Options`.
 *
 * @param rawAudio Input audio. Only its first channel is processed, and its sample rate
 *   must match the model (48000 Hz for 'baseline-48khz', 16000 Hz for 'baseline-16khz').
 * @param options Model name, directory containing the `.onnx` file, ONNX execution
 *   provider, and maximum attenuation in decibels.
 * @returns `{ denoisedAudio }`, a single-channel signal at the input sample rate.
 * @throws If no model directory path is available, the sample rate does not match the
 *   model, or the model cannot be found/run.
 */
export async function denoiseAudio(rawAudio: RawAudio, options: NSNet2Options): Promise<{ denoisedAudio: RawAudio }> {
	const model = options.model ?? defaultNSNet2Options.model!
	const modelDirectoryPath = options.modelDirectoryPath ?? defaultNSNet2Options.modelDirectoryPath
	const maxAttenuation = options.maxAttenuation ?? defaultNSNet2Options.maxAttenuation!
	const provider = options.provider ?? defaultNSNet2Options.provider

	if (!modelDirectoryPath) {
		throw new Error(`No NSNet2 model directory path was specified`)
	}

	// With no explicit provider an empty list is passed on; getOnnxSessionOptions presumably
	// applies its own defaults in that case (previously noted here as ['dml', 'cpu']) — confirm.
	const onnxExecutionProviders: OnnxExecutionProvider[] = provider ? [provider] : []

	const denoiser = new NSNet2(model, modelDirectoryPath, onnxExecutionProviders, maxAttenuation)

	return denoiser.denoiseAudio(rawAudio)
}

// Converts one STFT frame stored as interleaved [real, imaginary, real, imaginary, ...]
// values into log10 power values, one per bin. Power is floored at 1e-12 so log10 never
// receives zero.
function computeLogPowerSpectrum(frame: ArrayLike<number>): Float32Array {
	const binCount = frame.length / 2
	const logPowerSpectrum = new Float32Array(binCount)

	for (let binIndex = 0; binIndex < binCount; binIndex++) {
		const real = frame[binIndex * 2]
		const imaginary = frame[(binIndex * 2) + 1]

		const powerValue = (real ** 2) + (imaginary ** 2)

		logPowerSpectrum[binIndex] = Math.log10(Math.max(powerValue, 1e-12))
	}

	return logPowerSpectrum
}

/**
 * NSNet2 noise suppressor backed by an ONNX model.
 *
 * The ONNX inference session is created on the first call to `denoiseAudio` and reused
 * by later calls on the same instance.
 */
export class NSNet2 {
	session?: Onnx.InferenceSession

	/**
	 * @param modelName Which NSNet2 variant the model file in `modelDirectoryPath` is.
	 * @param modelDirectoryPath Directory searched for the first file ending in `.onnx`.
	 * @param executionProviders ONNX execution providers passed to the session options.
	 * @param maxAttenuation Maximum attenuation in decibels; model gains are clipped so no
	 *   bin is attenuated by more than this.
	 */
	constructor(
		public readonly modelName: NSNet2ModelName,
		public readonly modelDirectoryPath: string,
		public readonly executionProviders: OnnxExecutionProvider[],
		public readonly maxAttenuation: number) {
	}

	/**
	 * Denoises the first channel of `rawAudio`. Other channels are ignored.
	 *
	 * Steps:
	 * 1. Computes a Hann-windowed STFT with 50% overlap.
	 * 2. Feeds the log10 power spectrogram to the model.
	 * 3. Multiplies each bin by the model output, clipped to
	 *    [gain(-maxAttenuation dB), 1] and scaled by 1 / fftSize.
	 * 4. Resynthesizes the signal with the inverse STFT.
	 *
	 * @returns `{ denoisedAudio }`, a single-channel signal at the input sample rate.
	 * @throws If the sample rate does not match the model, the audio has no channels,
	 *   or the model cannot be loaded or produces too little output.
	 */
	async denoiseAudio(rawAudio: RawAudio): Promise<{ denoisedAudio: RawAudio }> {
		// Validate the input before paying the cost of loading the model
		const fftSize = this.getFftSize(rawAudio.sampleRate)
		const inputSignal = rawAudio.audioChannels[0]

		if (!inputSignal) {
			throw new Error(`Input audio has no channels`)
		}

		const logger = new Logger()

		logger.start(`Initialize ONNX model ${this.modelName}`)
		await this.initializeIfNeeded()

		const fftHopSize = fftSize / 2
		const fftRealBinCount = (fftSize / 2) + 1

		logger.start('Compute STFT frames')
		const stftrFrames = await stftr(inputSignal, fftSize, fftSize, fftHopSize, 'hann')

		// Feature extraction and inference happen in a separate method so the intermediate
		// spectrogram is unreachable once it returns
		const gains = await this.predictGains(stftrFrames, fftRealBinCount, logger)

		logger.start('Apply model output as a filter to original STFT frames')
		{
			// 1 / fftSize is folded into the per-bin scale, presumably to normalize for
			// stiftr — confirm against its implementation
			const fftSizeReciprocal = 1 / fftSize
			const minGainRatio = decibelsToGainFactor(-this.maxAttenuation)
			const maxGainRatio = 1.0

			let gainReadIndex = 0

			for (const frame of stftrFrames) {
				let frameWriteIndex = 0

				for (let binIndex = 0; binIndex < fftRealBinCount; binIndex++) {
					const gainRatio = clip(gains[gainReadIndex++], minGainRatio, maxGainRatio)
					const scale = gainRatio * fftSizeReciprocal

					// Real and imaginary parts of the bin get the same real-valued gain
					frame[frameWriteIndex++] *= scale
					frame[frameWriteIndex++] *= scale
				}
			}
		}

		logger.start('Reconstruct filtered signal using inverse STFT')
		const filteredSignal = await stiftr(stftrFrames, fftSize, fftSize, fftHopSize, 'hann')

		const denoisedAudio: RawAudio = { audioChannels: [filteredSignal], sampleRate: rawAudio.sampleRate }

		logger.end()

		return { denoisedAudio }
	}

	// Returns the FFT size used by the configured model, after checking that the input
	// sample rate is the one the model requires.
	private getFftSize(sampleRate: number): number {
		if (this.modelName === 'baseline-48khz') {
			if (sampleRate !== 48000) {
				throw new Error(`Denoising model baseline-48khz requires a 48000 Hz signal`)
			}

			return 1024
		} else if (this.modelName === 'baseline-16khz') {
			if (sampleRate !== 16000) {
				throw new Error(`Denoising model baseline-16khz requires a 16000 Hz signal`)
			}

			return 320
		} else {
			throw new Error(`Unsupported model name: ${this.modelName}`)
		}
	}

	// Computes the log-power spectrogram of the STFT frames, runs it through the model
	// and returns the flattened model output, read as frameCount * fftRealBinCount gains.
	// Assumes each frame holds fftRealBinCount interleaved complex bins (confirm against stftr).
	private async predictGains(stftrFrames: ArrayLike<number>[], fftRealBinCount: number, logger: Logger): Promise<Float32Array> {
		logger.start('Compute log-power spectrogram')
		const logPowerSpectrogram = stftrFrames.map(frame => computeLogPowerSpectrum(frame))

		logger.start('Process log-power spectrogram using ONNX model')
		const session = this.session

		if (!session) {
			throw new Error(`ONNX session has not been initialized`)
		}

		const frameCount = logPowerSpectrogram.length

		const Onnx = await import('onnxruntime-node')

		const flattenedFeatures = concatFloat32Arrays(logPowerSpectrogram)
		const inputTensor = new Onnx.Tensor('float32', flattenedFeatures, [1, frameCount, fftRealBinCount])

		const result = await session.run({ input: inputTensor })
		const flattenedOutput = result.output.data as Float32Array

		// Guard against reading past the end of the output, which would yield NaN gains
		const expectedOutputLength = frameCount * fftRealBinCount

		if (flattenedOutput.length < expectedOutputLength) {
			throw new Error(`ONNX model output has ${flattenedOutput.length} values, expected at least ${expectedOutputLength}`)
		}

		return flattenedOutput
	}

	// Creates the ONNX session from the first `.onnx` file listed in modelDirectoryPath,
	// unless a session already exists.
	private async initializeIfNeeded() {
		if (this.session) {
			return
		}

		const filesInModelPath = await readdir(this.modelDirectoryPath)
		const onnxModelFilename = filesInModelPath.find(filename => filename.endsWith('.onnx'))

		if (!onnxModelFilename) {
			throw new Error(`Couldn't find any ONNX model file in ${this.modelDirectoryPath}`)
		}

		const onnxModelPath = joinPath(this.modelDirectoryPath, onnxModelFilename)

		const Onnx = await import('onnxruntime-node')

		const onnxSessionOptions = getOnnxSessionOptions({ executionProviders: this.executionProviders })

		this.session = await Onnx.InferenceSession.create(onnxModelPath, onnxSessionOptions)
	}
}

//export type NSNet2ModelName = 'nsnet2-20ms-baseline' | 'nsnet2-20ms-48k-baseline'

/** Supported NSNet2 variants: 'baseline-16khz' needs 16000 Hz input, 'baseline-48khz' needs 48000 Hz. */
export type NSNet2ModelName = 'baseline-16khz' | 'baseline-48khz'

/** Defaults used by `denoiseAudio` for any option the caller leaves undefined. */
export const defaultNSNet2Options: NSNet2Options = {
	model: 'baseline-48khz',
	modelDirectoryPath: undefined,
	provider: undefined,
	maxAttenuation: 30,
}

/** Options for `denoiseAudio`. */
export interface NSNet2Options {
	/** NSNet2 variant; must match the input sample rate. */
	model?: NSNet2ModelName

	/** Directory searched for the first `.onnx` model file. Required (no default). */
	modelDirectoryPath?: string

	/** ONNX execution provider; when undefined no provider is requested explicitly. */
	provider?: OnnxExecutionProvider

	/** Maximum attenuation, in decibels, applied to any frequency bin. */
	maxAttenuation?: number
}