
echogarden

An easy-to-use speech toolset. Includes tools for synthesis, recognition, alignment, speech translation, language detection, source separation and more.

import { getEmptyRawAudio } from '../audio/AudioUtilities.js';
import { getWindowWeights, createStftrGenerator, stiftr } from '../dsp/FFT.js';
import { logToStderr } from '../utilities/Utilities.js';
import { Logger } from '../utilities/Logger.js';
import { dmlProviderAvailable, getOnnxSessionOptions } from '../utilities/OnnxUtilities.js';
import chalk from 'chalk';
import { WindowedList } from '../data-structures/WindowedList.js';
import { logLevelGreaterOrEqualTo } from '../api/API.js';

const log = logToStderr;

export async function isolate(rawAudio, modelFilePath, modelProfile, options) {
    const model = new MDXNet(modelFilePath, modelProfile, options);

    return model.processAudio(rawAudio);
}

export class MDXNet {
    modelFilePath;
    modelProfile;
    options;
    session;
    onnxSessionOptions;

    constructor(modelFilePath, modelProfile, options) {
        this.modelFilePath = modelFilePath;
        this.modelProfile = modelProfile;
        this.options = options;
    }

    async processAudio(rawAudio) {
        if (rawAudio.audioChannels.length !== 2) {
            throw new Error(`Input audio must be stereo`);
        }

        if (rawAudio.sampleRate !== this.modelProfile.sampleRate) {
            throw new Error(`Input audio must have a sample rate of ${this.modelProfile.sampleRate} Hz`);
        }

        if (rawAudio.audioChannels[0].length === 0) {
            return getEmptyRawAudio(rawAudio.audioChannels.length, rawAudio.sampleRate);
        }

        const enableTraceLogging = logLevelGreaterOrEqualTo('trace');

        const logger = new Logger();

        await logger.startAsync(`Initialize session for MDX-NET model '${this.options.model}'`);
        await this.initializeSessionIfNeeded();
        logger.end();

        logger.logTitledMessage(`Using ONNX execution provider`, `${this.onnxSessionOptions.executionProviders.join(', ')}`);

        const Onnx = await import('onnxruntime-node');

        const sampleRate = this.modelProfile.sampleRate;
        const fftSize = this.modelProfile.fftSize;
        const fftWindowSize = this.modelProfile.fftWindowSize;
        const fftHopSize = this.modelProfile.fftHopSize;
        const fftWindowType = this.modelProfile.fftWindowType;
        const binCount = this.modelProfile.binCount;
        const segmentSize = this.modelProfile.segmentSize;
        const segmentHopSize = this.modelProfile.segmentHopSize;

        const sampleCount = rawAudio.audioChannels[0].length;
        const fftSizeReciprocal = 1 / fftSize;

        // Initialize generators for STFT frames for each channel
        const fftFramesLeftGenerator = await createStftrGenerator(rawAudio.audioChannels[0], fftSize, fftWindowSize, fftHopSize, fftWindowType);
        const fftFramesRightGenerator = await createStftrGenerator(rawAudio.audioChannels[1], fftSize, fftWindowSize, fftHopSize, fftWindowType);

        // Initial windowed lists to store recently computed STFT frames
        const fftFramesLeftWindowedList = new WindowedList(segmentSize);
        const fftFramesRightWindowedList = new WindowedList(segmentSize);

        const audioForSegments = [];

        for (let segmentStartFrameOffset = 0; ; segmentStartFrameOffset += segmentHopSize) {
            const segmentEndFrameOffset = segmentStartFrameOffset + segmentSize;

            const timePosition = segmentStartFrameOffset * (fftHopSize / sampleRate);

            if (enableTraceLogging) {
                await logger.startAsync(`Compute STFT of segment at time position ${timePosition.toFixed(2)}`, undefined, chalk.magentaBright);
            } else {
                await logger.startAsync(`Process segment at time position ${timePosition.toFixed(2)}`);
            }

            while (fftFramesLeftWindowedList.endOffset < segmentEndFrameOffset) {
                const nextLeftFrameResult = fftFramesLeftGenerator.next();

                if (nextLeftFrameResult.done) {
                    break;
                }

                const nextRightFrameResult = fftFramesRightGenerator.next();

                if (nextRightFrameResult.done) {
                    break;
                }

                fftFramesLeftWindowedList.add(nextLeftFrameResult.value);
                fftFramesRightWindowedList.add(nextRightFrameResult.value);
            }

            const fftFramesForSegment = [
                fftFramesLeftWindowedList.slice(segmentStartFrameOffset, segmentEndFrameOffset),
                fftFramesRightWindowedList.slice(segmentStartFrameOffset, segmentEndFrameOffset)
            ];

            const segmentLength = fftFramesForSegment[0].length;

            const isLastSegment = segmentLength < segmentSize;

            if (enableTraceLogging) {
                await logger.startAsync(`Reshape STFT frames`);
            }

            const flattenedInputTensor = new Float32Array(1 * 4 * binCount * segmentSize);

            {
                let writePosition = 0;

                // 4 tensor elements are structured as:
                // <Channel 0 real> <Channel 0 imaginary> <Channel 1 real> <Channel 1 imaginary>
                for (let tensorElementIndex = 0; tensorElementIndex < 4; tensorElementIndex++) {
                    const isRealComponentTensorElementIndex = tensorElementIndex % 2 === 0;
                    const audioChannelIndex = tensorElementIndex < 2 ? 0 : 1;

                    for (let binIndex = 0; binIndex < binCount; binIndex++) {
                        for (let frameIndex = 0; frameIndex < segmentSize; frameIndex++) {
                            let value = 0;

                            if (frameIndex < segmentLength) {
                                const frame = fftFramesForSegment[audioChannelIndex][frameIndex];

                                if (isRealComponentTensorElementIndex) {
                                    value = frame[binIndex << 1];
                                } else {
                                    value = frame[(binIndex << 1) + 1];
                                }
                            }

                            flattenedInputTensor[writePosition++] = value;
                        }
                    }
                }
            }

            if (enableTraceLogging) {
                await logger.startAsync(`Process segment with MDXNet model`);
            }

            const inputTensor = new Onnx.Tensor('float32', flattenedInputTensor, [1, 4, binCount, segmentSize]);

            const { output: outputTensor } = await this.session.run({ input: inputTensor });

            if (enableTraceLogging) {
                await logger.startAsync('Reshape processed frames');
            }

            const flattenedOutputTensor = outputTensor.data;

            const outputSegmentFramesForChannel = [];

            {
                for (let outChannelIndex = 0; outChannelIndex < 2; outChannelIndex++) {
                    const framesForChannel = [];

                    for (let frameIndex = 0; frameIndex < segmentSize; frameIndex++) {
                        const frame = new Float32Array(fftSize);

                        framesForChannel.push(frame);
                    }

                    outputSegmentFramesForChannel.push(framesForChannel);
                }

                let readPosition = 0;

                for (let tensorChannelIndex = 0; tensorChannelIndex < 4; tensorChannelIndex++) {
                    const isRealTensorChannelIndex = tensorChannelIndex % 2 === 0;
                    const audioChannelIndex = tensorChannelIndex < 2 ? 0 : 1;

                    const framesForOutputChannel = outputSegmentFramesForChannel[audioChannelIndex];

                    for (let binIndex = 0; binIndex < binCount; binIndex++) {
                        for (let frameIndex = 0; frameIndex < segmentSize; frameIndex++) {
                            const outFrame = framesForOutputChannel[frameIndex];

                            if (isRealTensorChannelIndex) {
                                outFrame[binIndex << 1] = flattenedOutputTensor[readPosition++] * fftSizeReciprocal;
                            } else {
                                outFrame[(binIndex << 1) + 1] = flattenedOutputTensor[readPosition++] * fftSizeReciprocal;
                            }
                        }
                    }
                }
            }

            const outputAudioChannels = [];

            if (enableTraceLogging) {
                await logger.startAsync(`Compute inverse STFT of model output for segment`);
            }

            for (let channelIndex = 0; channelIndex < 2; channelIndex++) {
                const samples = await stiftr(outputSegmentFramesForChannel[channelIndex], fftSize, fftWindowSize, fftHopSize, fftWindowType);

                outputAudioChannels.push(samples);
            }

            audioForSegments.push(outputAudioChannels);

            if (isLastSegment) {
                break;
            }
        }

        // Join segments using overlapping Hann windows
        await logger.startAsync(`Join segments`);

        const joinedSegments = [new Float32Array(sampleCount), new Float32Array(sampleCount)];

        {
            const segmentCount = audioForSegments.length;
            const segmentSampleCount = audioForSegments[0][0].length;

            const windowWeights = getWindowWeights('hann', segmentSampleCount);

            const sumOfWeightsForSample = new Float32Array(sampleCount);

            for (let segmentIndex = 0; segmentIndex < segmentCount; segmentIndex++) {
                const segmentStartFrameIndex = segmentIndex * segmentHopSize;
                const segmentStartSampleIndex = segmentStartFrameIndex * fftHopSize;

                const segmentSamples = audioForSegments[segmentIndex];

                for (let segmentSampleOffset = 0; segmentSampleOffset < segmentSampleCount; segmentSampleOffset++) {
                    const sampleIndex = segmentStartSampleIndex + segmentSampleOffset;

                    if (sampleIndex >= sampleCount) {
                        break;
                    }

                    const weight = windowWeights[segmentSampleOffset];

                    for (let channelIndex = 0; channelIndex < 2; channelIndex++) {
                        joinedSegments[channelIndex][sampleIndex] += segmentSamples[channelIndex][segmentSampleOffset] * weight;
                    }

                    sumOfWeightsForSample[sampleIndex] += weight;
                }
            }

            for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex++) {
                for (let channelIndex = 0; channelIndex < 2; channelIndex++) {
                    joinedSegments[channelIndex][sampleIndex] /= sumOfWeightsForSample[sampleIndex] + 1e-8;
                }
            }
        }

        const isolatedRawAudio = {
            audioChannels: joinedSegments,
            sampleRate
        };

        logger.end();

        return isolatedRawAudio;
    }

    async initializeSessionIfNeeded() {
        if (this.session) {
            return;
        }

        const Onnx = await import('onnxruntime-node');

        const executionProviders = this.options.provider ? [this.options.provider] : getDefaultMDXNetProviders();

        this.onnxSessionOptions = getOnnxSessionOptions({ executionProviders });

        this.session = await Onnx.InferenceSession.create(this.modelFilePath, this.onnxSessionOptions);
    }
}

export function getDefaultMDXNetProviders() {
    if (dmlProviderAvailable()) {
        return ['dml', 'cpu'];
    } else {
        return [];
    }
}

export function getProfileForMDXNetModelName(modelName) {
    if (['UVR_MDXNET_1_9703', 'UVR_MDXNET_2_9682', 'UVR_MDXNET_3_9662', 'UVR_MDXNET_KARA'].includes(modelName)) {
        return mdxNetModelProfile1;
    }

    if (['UVR_MDXNET_Main', 'Kim_Vocal_1', 'Kim_Vocal_2'].includes(modelName)) {
        return mdxNetModelProfile2;
    }

    throw new Error(`Unsupported model name: '${modelName}'`);
}

export const mdxNetModelProfile1 = {
    sampleRate: 44100,
    fftSize: 6144,
    fftWindowSize: 6144,
    fftHopSize: 1024,
    fftWindowType: 'hann',
    binCount: 2048,
    segmentSize: 256,
    segmentHopSize: 224,
};

export const mdxNetModelProfile2 = {
    sampleRate: 44100,
    fftSize: 7680,
    fftWindowSize: 7680,
    fftHopSize: 1024,
    fftWindowType: 'hann',
    binCount: 3072,
    segmentSize: 256,
    segmentHopSize: 224,
};

export const defaultMDXNetOptions = {
    model: 'UVR_MDXNET_1_9703',
    provider: undefined,
};

//# sourceMappingURL=MDXNetSourceSeparation.js.map
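
A rough usage sketch (not part of the module above): the exported `isolate` function takes a stereo raw audio object (`{ audioChannels: [Float32Array, Float32Array], sampleRate }`) already at the profile's 44100 Hz sample rate, a path to a downloaded MDX-NET ONNX model file, the matching profile, and an options object. The `loadStereoRawAudio` helper and the model file path below are hypothetical placeholders, shown only to illustrate the call shape.

import { isolate, getProfileForMDXNetModelName, defaultMDXNetOptions } from './MDXNetSourceSeparation.js';

const modelName = 'UVR_MDXNET_1_9703';
const profile = getProfileForMDXNetModelName(modelName);

// Hypothetical helper: load a stereo file as raw audio, resampled to the profile's sample rate
const rawAudio = await loadStereoRawAudio('song.wav', profile.sampleRate);

// Placeholder path to a locally downloaded MDX-NET ONNX model
const modelFilePath = './models/UVR_MDXNET_1_9703.onnx';

const isolated = await isolate(rawAudio, modelFilePath, profile, { ...defaultMDXNetOptions, model: modelName });
// `isolated` is a raw audio object with the same sample rate, containing the separated stem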