echogarden
An easy-to-use speech toolset. Includes tools for synthesis, recognition, alignment, speech translation, language detection, source separation and more.
import { getEmptyRawAudio } from '../audio/AudioUtilities.js';
import { getWindowWeights, createStftrGenerator, stiftr } from '../dsp/FFT.js';
import { logToStderr } from '../utilities/Utilities.js';
import { Logger } from '../utilities/Logger.js';
import { dmlProviderAvailable, getOnnxSessionOptions } from '../utilities/OnnxUtilities.js';
import chalk from 'chalk';
import { WindowedList } from '../data-structures/WindowedList.js';
import { logLevelGreaterOrEqualTo } from '../api/API.js';
const log = logToStderr;
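// Top-level entry point: runs an MDX-Net source separation model over the given stereo audio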
export async function isolate(rawAudio, modelFilePath, modelProfile, options) {
const model = new MDXNet(modelFilePath, modelProfile, options);
return model.processAudio(rawAudio);
}
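// Holds an ONNX inference session for an MDX-Net model and applies it to stereo audio in overlapping segments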
export class MDXNet {
modelFilePath;
modelProfile;
options;
session;
onnxSessionOptions;
constructor(modelFilePath, modelProfile, options) {
this.modelFilePath = modelFilePath;
this.modelProfile = modelProfile;
this.options = options;
}
async processAudio(rawAudio) {
if (rawAudio.audioChannels.length !== 2) {
throw new Error(`Input audio must be stereo`);
}
if (rawAudio.sampleRate !== this.modelProfile.sampleRate) {
throw new Error(`Input audio must have a sample rate of ${this.modelProfile.sampleRate} Hz`);
}
if (rawAudio.audioChannels[0].length === 0) {
return getEmptyRawAudio(rawAudio.audioChannels.length, rawAudio.sampleRate);
}
const enableTraceLogging = logLevelGreaterOrEqualTo('trace');
const logger = new Logger();
await logger.startAsync(`Initialize session for MDX-NET model '${this.options.model}'`);
await this.initializeSessionIfNeeded();
logger.end();
logger.logTitledMessage(`Using ONNX execution provider`, `${this.onnxSessionOptions.executionProviders.join(', ')}`);
const Onnx = await import('onnxruntime-node');
const sampleRate = this.modelProfile.sampleRate;
const fftSize = this.modelProfile.fftSize;
const fftWindowSize = this.modelProfile.fftWindowSize;
const fftHopSize = this.modelProfile.fftHopSize;
const fftWindowType = this.modelProfile.fftWindowType;
const binCount = this.modelProfile.binCount;
const segmentSize = this.modelProfile.segmentSize;
const segmentHopSize = this.modelProfile.segmentHopSize;
const sampleCount = rawAudio.audioChannels[0].length;
const fftSizeReciprocal = 1 / fftSize;
// Initialize generators for STFT frames for each channel
const fftFramesLeftGenerator = await createStftrGenerator(rawAudio.audioChannels[0], fftSize, fftWindowSize, fftHopSize, fftWindowType);
const fftFramesRightGenerator = await createStftrGenerator(rawAudio.audioChannels[1], fftSize, fftWindowSize, fftHopSize, fftWindowType);
// Initialize windowed lists to store recently computed STFT frames for each channel
const fftFramesLeftWindowedList = new WindowedList(segmentSize);
const fftFramesRightWindowedList = new WindowedList(segmentSize);
const audioForSegments = [];
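// Process the audio in overlapping segments of STFT frames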
for (let segmentStartFrameOffset = 0;; segmentStartFrameOffset += segmentHopSize) {
const segmentEndFrameOffset = segmentStartFrameOffset + segmentSize;
const timePosition = segmentStartFrameOffset * (fftHopSize / sampleRate);
if (enableTraceLogging) {
await logger.startAsync(`Compute STFT of segment at time position ${timePosition.toFixed(2)}`, undefined, chalk.magentaBright);
}
else {
await logger.startAsync(`Process segment at time position ${timePosition.toFixed(2)}`);
}
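// Advance both STFT generators until the buffered frames cover this segment (or the audio ends)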
while (fftFramesLeftWindowedList.endOffset < segmentEndFrameOffset) {
const nextLeftFrameResult = fftFramesLeftGenerator.next();
if (nextLeftFrameResult.done) {
break;
}
const nextRightFrameResult = fftFramesRightGenerator.next();
if (nextRightFrameResult.done) {
break;
}
fftFramesLeftWindowedList.add(nextLeftFrameResult.value);
fftFramesRightWindowedList.add(nextRightFrameResult.value);
}
const fftFramesForSegment = [
fftFramesLeftWindowedList.slice(segmentStartFrameOffset, segmentEndFrameOffset),
fftFramesRightWindowedList.slice(segmentStartFrameOffset, segmentEndFrameOffset)
];
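// The last segment may contain fewer than segmentSize frames; missing frames are zero-filled below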
const segmentLength = fftFramesForSegment[0].length;
const isLastSegment = segmentLength < segmentSize;
if (enableTraceLogging) {
await logger.startAsync(`Reshape STFT frames`);
}
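// Flatten the segment's STFT frames into the model's input layout: [batch = 1, 4, binCount, segmentSize]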
const flattenedInputTensor = new Float32Array(1 * 4 * binCount * segmentSize);
{
let writePosition = 0;
// 4 tensor elements are structured as:
// <Channel 0 real> <Channel 0 imaginary> <Channel 1 real> <Channel 1 imaginary>
for (let tensorElementIndex = 0; tensorElementIndex < 4; tensorElementIndex++) {
const isRealComponentTensorElementIndex = tensorElementIndex % 2 === 0;
const audioChannelIndex = tensorElementIndex < 2 ? 0 : 1;
for (let binIndex = 0; binIndex < binCount; binIndex++) {
for (let frameIndex = 0; frameIndex < segmentSize; frameIndex++) {
let value = 0;
if (frameIndex < segmentLength) {
const frame = fftFramesForSegment[audioChannelIndex][frameIndex];
if (isRealComponentTensorElementIndex) {
value = frame[binIndex << 1];
}
else {
value = frame[(binIndex << 1) + 1];
}
}
flattenedInputTensor[writePosition++] = value;
}
}
}
}
if (enableTraceLogging) {
await logger.startAsync(`Process segment with MDXNet model`);
}
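// Run the ONNX session on the flattened segment; the output tensor is read back using the same [1, 4, binCount, segmentSize] layout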
const inputTensor = new Onnx.Tensor('float32', flattenedInputTensor, [1, 4, binCount, segmentSize]);
const { output: outputTensor } = await this.session.run({ input: inputTensor });
if (enableTraceLogging) {
await logger.startAsync('Reshape processed frames');
}
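// Unpack the output tensor back into per-channel complex STFT frames, scaling each value by 1 / fftSize before the inverse STFT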
const flattenedOutputTensor = outputTensor.data;
const outputSegmentFramesForChannel = [];
{
for (let outChannelIndex = 0; outChannelIndex < 2; outChannelIndex++) {
const framesForChannel = [];
for (let frameIndex = 0; frameIndex < segmentSize; frameIndex++) {
const frame = new Float32Array(fftSize);
framesForChannel.push(frame);
}
outputSegmentFramesForChannel.push(framesForChannel);
}
let readPosition = 0;
for (let tensorChannelIndex = 0; tensorChannelIndex < 4; tensorChannelIndex++) {
const isRealTensorChannelIndex = tensorChannelIndex % 2 === 0;
const audioChannelIndex = tensorChannelIndex < 2 ? 0 : 1;
const framesForOutputChannel = outputSegmentFramesForChannel[audioChannelIndex];
for (let binIndex = 0; binIndex < binCount; binIndex++) {
for (let frameIndex = 0; frameIndex < segmentSize; frameIndex++) {
const outFrame = framesForOutputChannel[frameIndex];
if (isRealTensorChannelIndex) {
outFrame[binIndex << 1] = flattenedOutputTensor[readPosition++] * fftSizeReciprocal;
}
else {
outFrame[(binIndex << 1) + 1] = flattenedOutputTensor[readPosition++] * fftSizeReciprocal;
}
}
}
}
}
const outputAudioChannels = [];
if (enableTraceLogging) {
await logger.startAsync(`Compute inverse STFT of model output for segment`);
}
for (let channelIndex = 0; channelIndex < 2; channelIndex++) {
const samples = await stiftr(outputSegmentFramesForChannel[channelIndex], fftSize, fftWindowSize, fftHopSize, fftWindowType);
outputAudioChannels.push(samples);
}
audioForSegments.push(outputAudioChannels);
if (isLastSegment) {
break;
}
}
// Join segments using overlapping Hann windows
await logger.startAsync(`Join segments`);
const joinedSegments = [new Float32Array(sampleCount), new Float32Array(sampleCount)];
{
const segmentCount = audioForSegments.length;
const segmentSampleCount = audioForSegments[0][0].length;
const windowWeights = getWindowWeights('hann', segmentSampleCount);
const sumOfWeightsForSample = new Float32Array(sampleCount);
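// Overlap-add each segment, weighting samples by a Hann window and accumulating the total weight applied to each sample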
for (let segmentIndex = 0; segmentIndex < segmentCount; segmentIndex++) {
const segmentStartFrameIndex = segmentIndex * segmentHopSize;
const segmentStartSampleIndex = segmentStartFrameIndex * fftHopSize;
const segmentSamples = audioForSegments[segmentIndex];
for (let segmentSampleOffset = 0; segmentSampleOffset < segmentSampleCount; segmentSampleOffset++) {
const sampleIndex = segmentStartSampleIndex + segmentSampleOffset;
if (sampleIndex >= sampleCount) {
break;
}
const weight = windowWeights[segmentSampleOffset];
for (let channelIndex = 0; channelIndex < 2; channelIndex++) {
joinedSegments[channelIndex][sampleIndex] += segmentSamples[channelIndex][segmentSampleOffset] * weight;
}
sumOfWeightsForSample[sampleIndex] += weight;
}
}
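// Normalize each sample by its accumulated window weight (the small epsilon avoids division by zero)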
for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex++) {
for (let channelIndex = 0; channelIndex < 2; channelIndex++) {
joinedSegments[channelIndex][sampleIndex] /= sumOfWeightsForSample[sampleIndex] + 1e-8;
}
}
}
const isolatedRawAudio = { audioChannels: joinedSegments, sampleRate };
logger.end();
return isolatedRawAudio;
}
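// Create the ONNX inference session on first use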
async initializeSessionIfNeeded() {
if (this.session) {
return;
}
const Onnx = await import('onnxruntime-node');
const executionProviders = this.options.provider ? [this.options.provider] : getDefaultMDXNetProviders();
this.onnxSessionOptions = getOnnxSessionOptions({ executionProviders });
this.session = await Onnx.InferenceSession.create(this.modelFilePath, this.onnxSessionOptions);
}
}
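// Prefer the DirectML execution provider (with CPU fallback) when available; otherwise return an empty list (no explicit provider preference)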
export function getDefaultMDXNetProviders() {
if (dmlProviderAvailable()) {
return ['dml', 'cpu'];
}
else {
return [];
}
}
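// Select the STFT/segmentation profile matching a known MDX-Net model name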
export function getProfileForMDXNetModelName(modelName) {
if (['UVR_MDXNET_1_9703', 'UVR_MDXNET_2_9682', 'UVR_MDXNET_3_9662', 'UVR_MDXNET_KARA'].includes(modelName)) {
return mdxNetModelProfile1;
}
if (['UVR_MDXNET_Main', 'Kim_Vocal_1', 'Kim_Vocal_2'].includes(modelName)) {
return mdxNetModelProfile2;
}
throw new Error(`Unsupported model name: '${modelName}'`);
}
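// Profile for the UVR_MDXNET 1/2/3 and KARA models (6144-point FFT, 2048 bins)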
export const mdxNetModelProfile1 = {
sampleRate: 44100,
fftSize: 6144,
fftWindowSize: 6144,
fftHopSize: 1024,
fftWindowType: 'hann',
binCount: 2048,
segmentSize: 256,
segmentHopSize: 224,
};
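// Profile for the UVR_MDXNET_Main and Kim_Vocal models (7680-point FFT, 3072 bins)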
export const mdxNetModelProfile2 = {
sampleRate: 44100,
fftSize: 7680,
fftWindowSize: 7680,
fftHopSize: 1024,
fftWindowType: 'hann',
binCount: 3072,
segmentSize: 256,
segmentHopSize: 224,
};
export const defaultMDXNetOptions = {
model: 'UVR_MDXNET_1_9703',
provider: undefined,
};
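// Example usage (a minimal sketch, not executed by this module; the model file
// path is hypothetical and must point to a downloaded MDX-Net ONNX model, and
// `rawAudio` must be a 44100 Hz stereo RawAudio object):
//
//   const modelName = 'UVR_MDXNET_1_9703'
//   const profile = getProfileForMDXNetModelName(modelName)
//   const isolatedAudio = await isolate(rawAudio, '/path/to/UVR_MDXNET_1_9703.onnx', profile, { ...defaultMDXNetOptions, model: modelName })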
//# sourceMappingURL=MDXNetSourceSeparation.js.map