echogarden
An easy-to-use speech toolset. Includes tools for synthesis, recognition, alignment, speech translation, language detection, source separation and more.
JavaScript
import { decibelsToGainFactor } from '../audio/AudioUtilities.js';
import { getOnnxSessionOptions } from '../utilities/OnnxUtilities.js';
import { readdir } from '../utilities/FileSystem.js';
import { joinPath } from '../utilities/PathUtilities.js';
import { stftr, stiftr } from '../dsp/FFT.js';
import { clip, concatFloat32Arrays } from '../utilities/Utilities.js';
import { Logger } from '../utilities/Logger.js';
export async function denoiseAudio(rawAudio, options) {
const onnxExecutionProviders = options.provider ? [options.provider] : []; // e.g. ['dml', 'cpu']
const denoiser = new NSNet2(options.model, options.modelDirectoryPath, onnxExecutionProviders, options.maxAttenuation);
const result = await denoiser.denoiseAudio(rawAudio);
return result;
}
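// A minimal usage sketch (hypothetical values; assumes a 48 kHz mono signal
// and a local directory containing the ONNX model file):
//
//   const rawAudio = { audioChannels: [samples], sampleRate: 48000 }
//
//   const { denoisedAudio } = await denoiseAudio(rawAudio, {
//       ...defaultNSNet2Options,
//       modelDirectoryPath: './models/nsnet2-baseline-48khz', // hypothetical path
//   })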
export class NSNet2 {
modelName;
modelDirectoryPath;
executionProviders;
maxAttenuation;
session;
constructor(modelName, modelDirectoryPath, executionProviders, maxAttenuation) {
this.modelName = modelName;
this.modelDirectoryPath = modelDirectoryPath;
this.executionProviders = executionProviders;
this.maxAttenuation = maxAttenuation;
}
async denoiseAudio(rawAudio) {
const logger = new Logger();
logger.start(`Initialize ONNX model ${this.modelName}`);
await this.initializeIfNeeded();
let fftSize;
if (this.modelName === 'baseline-48khz') {
fftSize = 1024;
if (rawAudio.sampleRate !== 48000) {
throw new Error(`Denoising model baseline-48khz requires a 48000 Hz signal`);
}
}
else if (this.modelName === 'baseline-16khz') {
fftSize = 320;
if (rawAudio.sampleRate !== 16000) {
throw new Error(`Denoising model baseline-16khz requires a 16000 Hz signal`);
}
}
else {
throw new Error(`Unsupported model name: ${this.modelName}`);
}
const fftHopSize = fftSize / 2;
const fftRealBinCount = (fftSize / 2) + 1;
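// For a real-valued input, an FFT of size N yields N / 2 + 1 unique complex
// bins, so each STFT frame below holds 2 * fftRealBinCount floats (interleaved
// real and imaginary parts), and consecutive frames overlap by 50% (hop = N / 2).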
logger.start('Compute STFT frames');
const stftrFrames = await stftr(rawAudio.audioChannels[0], fftSize, fftSize, fftHopSize, 'hann');
logger.start('Compute log-power spectrogram');
let logPowerSpectrogram = [];
{
for (const frame of stftrFrames) {
const logPowerSpectrum = new Float32Array(frame.length / 2);
let readOffset = 0;
let writeOffset = 0;
while (readOffset < frame.length) {
const real = frame[readOffset++];
const imaginary = frame[readOffset++];
const powerValue = (real ** 2) + (imaginary ** 2);
const clampedPowerValue = Math.max(powerValue, 1e-12); // avoid log10(0)
const logPowerValue = Math.log10(clampedPowerValue);
logPowerSpectrum[writeOffset++] = logPowerValue;
}
logPowerSpectrogram.push(logPowerSpectrum);
}
}
logger.start('Process log-power spectrogram using ONNX model');
const frameCount = logPowerSpectrogram.length;
let flattenedOutputTensor;
{
const Onnx = await import('onnxruntime-node');
const flattenedFeatures = concatFloat32Arrays(logPowerSpectrogram);
const inputTensor = new Onnx.Tensor('float32', flattenedFeatures, [1, frameCount, fftRealBinCount]);
const inputs = { input: inputTensor };
const result = await this.session.run(inputs);
flattenedOutputTensor = result.output.data;
}
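// The model outputs one gain value per time-frequency bin, flattened in
// row-major order to frameCount * fftRealBinCount values.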
{
logger.start('Apply model output as a filter to original STFT frames');
const fftSizeReciprocal = 1 / fftSize;
const minGainRatio = decibelsToGainFactor(-this.maxAttenuation);
const maxGainRatio = 1.0;
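// A decibel value maps to a gain factor of 10 ** (dB / 20), so a
// maxAttenuation of 30 dB clips per-bin gains to no less than about 0.0316.
// The extra 1 / fftSize factor below presumably compensates for the scaling
// convention of the stftr / stiftr pair.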
let flattenedOutputTensorReadIndex = 0;
for (let frameIndex = 0; frameIndex < frameCount; frameIndex++) {
const frame = stftrFrames[frameIndex];
let frameReadIndex = 0;
for (let binIndex = 0; binIndex < fftRealBinCount; binIndex++) {
let gainRatio = flattenedOutputTensor[flattenedOutputTensorReadIndex++];
gainRatio = clip(gainRatio, minGainRatio, maxGainRatio);
frame[frameReadIndex++] *= gainRatio * fftSizeReciprocal;
frame[frameReadIndex++] *= gainRatio * fftSizeReciprocal;
}
}
}
// Allow logPowerSpectrogram to be garbage collected
logPowerSpectrogram = undefined;
logger.start('Reconstruct filtered signal using inverse STFT');
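// Invert the filtered frames with the same window and hop parameters used in
// the forward transform (50% overlap-add with a Hann window).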
const filteredSignal = await stiftr(stftrFrames, fftSize, fftSize, fftHopSize, 'hann');
const denoisedAudio = {
audioChannels: [filteredSignal],
sampleRate: rawAudio.sampleRate
};
logger.end();
return { denoisedAudio };
}
async initializeIfNeeded() {
if (this.session) {
return;
}
const filesInModelPath = await readdir(this.modelDirectoryPath);
const onnxModelFilename = filesInModelPath.find(filename => filename.endsWith('.onnx'));
if (!onnxModelFilename) {
throw new Error(`Couldn't find any ONNX model file in ${this.modelDirectoryPath}`);
}
const onnxModelPath = joinPath(this.modelDirectoryPath, onnxModelFilename);
const Onnx = await import('onnxruntime-node');
const onnxSessionOptions = getOnnxSessionOptions({ executionProviders: this.executionProviders });
this.session = await Onnx.InferenceSession.create(onnxModelPath, onnxSessionOptions);
}
}
export const defaultNSNet2Options = {
model: 'baseline-48khz',
modelDirectoryPath: undefined,
provider: undefined,
maxAttenuation: 30,
};
//# sourceMappingURL=NSNet2.js.map