UNPKG

echogarden

Version:

An easy-to-use speech toolset. Includes tools for synthesis, recognition, alignment, speech translation, language detection, source separation and more.

379 lines (274 loc) 11.3 kB
import type * as Onnx from 'onnxruntime-node'
import { getEmptyRawAudio, RawAudio } from '../audio/AudioUtilities.js'
import { getWindowWeights, createStftrGenerator, stiftr, WindowType } from '../dsp/FFT.js'
import { logToStderr } from '../utilities/Utilities.js'
import { Logger } from '../utilities/Logger.js'
import { OnnxExecutionProvider, dmlProviderAvailable, getOnnxSessionOptions } from '../utilities/OnnxUtilities.js'
import chalk from 'chalk'
import { WindowedList } from '../data-structures/WindowedList.js'
import { logLevelGreaterOrEqualTo } from '../api/API.js'

const log = logToStderr

/**
 * Runs MDX-NET source separation on a stereo recording and returns the isolated audio.
 *
 * Convenience wrapper that constructs a fresh `MDXNet` instance (and therefore a fresh
 * ONNX inference session) for this single call.
 *
 * @param rawAudio - Stereo input audio; its sample rate must equal `modelProfile.sampleRate`.
 * @param modelFilePath - Path to the ONNX model file.
 * @param modelProfile - STFT and segmentation parameters matching the model.
 * @param options - Model name (used for logging) and optional ONNX execution provider override.
 * @returns Stereo audio with the input's sample rate and length.
 * @throws Error if the input is not stereo or its sample rate does not match the profile.
 */
export async function isolate(
	rawAudio: RawAudio,
	modelFilePath: string,
	modelProfile: MDXNetModelProfile,
	options: MDXNetOptions) {

	const model = new MDXNet(modelFilePath, modelProfile, options)

	return model.processAudio(rawAudio)
}

/**
 * MDX-NET separation model backed by an ONNX Runtime inference session.
 *
 * The session is created lazily on the first call to `processAudio` and reused by later calls.
 */
export class MDXNet {
	session?: Onnx.InferenceSession
	onnxSessionOptions?: Onnx.InferenceSession.SessionOptions

	constructor(
		public readonly modelFilePath: string,
		public readonly modelProfile: MDXNetModelProfile,
		public readonly options: MDXNetOptions) {
	}

	/**
	 * Separates `rawAudio` with the model and returns the isolated stereo audio.
	 *
	 * The STFT of each channel is split into overlapping segments of `segmentSize` frames,
	 * advancing by `segmentHopSize` frames. Each segment is packed into a
	 * `[1, 4, binCount, segmentSize]` tensor, run through the model, and converted back to
	 * audio with an inverse STFT. The per-segment audio is then combined by a
	 * Hann-weighted average over the overlapping regions.
	 *
	 * @param rawAudio - Stereo audio at `modelProfile.sampleRate`.
	 * @returns Audio with the input's sample rate and length (empty input yields empty output).
	 * @throws Error if the input is not stereo or its sample rate does not match the profile.
	 */
	async processAudio(rawAudio: RawAudio) {
		if (rawAudio.audioChannels.length !== 2) {
			throw new Error(`Input audio must be stereo`)
		}

		if (rawAudio.sampleRate !== this.modelProfile.sampleRate) {
			throw new Error(`Input audio must have a sample rate of ${this.modelProfile.sampleRate} Hz`)
		}

		if (rawAudio.audioChannels[0].length === 0) {
			return getEmptyRawAudio(rawAudio.audioChannels.length, rawAudio.sampleRate)
		}

		const enableTraceLogging = logLevelGreaterOrEqualTo('trace')

		const logger = new Logger()

		// Fall back to the model file path when no model name was given, so the message
		// never reads "model 'undefined'"
		const modelDescription = this.options.model ?? this.modelFilePath

		await logger.startAsync(`Initialize session for MDX-NET model '${modelDescription}'`)
		await this.initializeSessionIfNeeded()
		logger.end()

		logger.logTitledMessage(`Using ONNX execution provider`, `${this.onnxSessionOptions!.executionProviders!.join(', ')}`)

		// Runtime module (the top-level import is type-only)
		const Onnx = await import('onnxruntime-node')

		const sampleRate = this.modelProfile.sampleRate
		const fftSize = this.modelProfile.fftSize
		const fftWindowSize = this.modelProfile.fftWindowSize
		const fftHopSize = this.modelProfile.fftHopSize
		const fftWindowType = this.modelProfile.fftWindowType
		const binCount = this.modelProfile.binCount
		const segmentSize = this.modelProfile.segmentSize
		const segmentHopSize = this.modelProfile.segmentHopSize

		const sampleCount = rawAudio.audioChannels[0].length

		// Model output is scaled by 1 / fftSize before the inverse STFT
		// NOTE(review): presumably compensates for an unnormalized transform in FFT.js — confirm
		const fftSizeReciprocal = 1 / fftSize

		// Initialize generators for STFT frames for each channel
		const fftFramesLeftGenerator = await createStftrGenerator(rawAudio.audioChannels[0], fftSize, fftWindowSize, fftHopSize, fftWindowType)
		const fftFramesRightGenerator = await createStftrGenerator(rawAudio.audioChannels[1], fftSize, fftWindowSize, fftHopSize, fftWindowType)

		// Initialize windowed lists to store recently computed STFT frames
		// (presumably retaining at most `segmentSize` frames — see WindowedList)
		const fftFramesLeftWindowedList = new WindowedList<Float32Array>(segmentSize)
		const fftFramesRightWindowedList = new WindowedList<Float32Array>(segmentSize)

		// Per-segment output audio, indexed as [segmentIndex][channelIndex]
		const audioForSegments: Float32Array[][] = []

		for (let segmentStartFrameOffset = 0; ; segmentStartFrameOffset += segmentHopSize) {
			const segmentEndFrameOffset = segmentStartFrameOffset + segmentSize

			const timePosition = segmentStartFrameOffset * (fftHopSize / sampleRate)

			if (enableTraceLogging) {
				await logger.startAsync(`Compute STFT of segment at time position ${timePosition.toFixed(2)}`, undefined, chalk.magentaBright)
			} else {
				await logger.startAsync(`Process segment at time position ${timePosition.toFixed(2)}`)
			}

			// Pull frames from both channel generators until this segment is covered or the audio ends
			while (fftFramesLeftWindowedList.endOffset < segmentEndFrameOffset) {
				const nextLeftFrameResult = fftFramesLeftGenerator.next()

				if (nextLeftFrameResult.done) {
					break
				}

				const nextRightFrameResult = fftFramesRightGenerator.next()

				if (nextRightFrameResult.done) {
					break
				}

				fftFramesLeftWindowedList.add(nextLeftFrameResult.value)
				fftFramesRightWindowedList.add(nextRightFrameResult.value)
			}

			const fftFramesForSegment = [
				fftFramesLeftWindowedList.slice(segmentStartFrameOffset, segmentEndFrameOffset),
				fftFramesRightWindowedList.slice(segmentStartFrameOffset, segmentEndFrameOffset)
			]

			// A segment shorter than `segmentSize` frames means the STFT has been exhausted
			const segmentLength = fftFramesForSegment[0].length
			const isLastSegment = segmentLength < segmentSize

			if (enableTraceLogging) {
				await logger.startAsync(`Reshape STFT frames`)
			}

			// Input tensor of shape [1, 4, binCount, segmentSize]; frames past `segmentLength` are zero
			const flattenedInputTensor = new Float32Array(1 * 4 * binCount * segmentSize)

			{
				let writePosition = 0

				// 4 tensor elements are structured as:
				// <Channel 0 real> <Channel 0 imaginary> <Channel 1 real> <Channel 1 imaginary>
				for (let tensorElementIndex = 0; tensorElementIndex < 4; tensorElementIndex++) {
					// STFT frames interleave components, so bin b's real part is at 2b and its imaginary part at 2b + 1
					const componentOffset = tensorElementIndex % 2
					const audioChannelIndex = tensorElementIndex < 2 ? 0 : 1
					const inputFramesForChannel = fftFramesForSegment[audioChannelIndex]

					for (let binIndex = 0; binIndex < binCount; binIndex++) {
						const frameElementIndex = (binIndex << 1) + componentOffset

						for (let frameIndex = 0; frameIndex < segmentSize; frameIndex++) {
							let value = 0

							if (frameIndex < segmentLength) {
								value = inputFramesForChannel[frameIndex][frameElementIndex]
							}

							flattenedInputTensor[writePosition++] = value
						}
					}
				}
			}

			if (enableTraceLogging) {
				await logger.startAsync(`Process segment with MDXNet model`)
			}

			const inputTensor = new Onnx.Tensor('float32', flattenedInputTensor, [1, 4, binCount, segmentSize])

			const { output: outputTensor } = await this.session!.run({ input: inputTensor })

			if (enableTraceLogging) {
				await logger.startAsync('Reshape processed frames')
			}

			// The output is read back in the same [4, binCount, segmentSize] order used for the input
			// NOTE(review): assumes the model's output layout mirrors its input — confirm
			const flattenedOutputTensor = outputTensor.data as Float32Array

			// Output frames, indexed as [channelIndex][frameIndex]
			const outputSegmentFramesForChannel: Float32Array[][] = []

			{
				// Allocate `segmentSize` zeroed frames of `fftSize` values per output channel.
				// Components for bins at or above `binCount` are left at zero.
				for (let outChannelIndex = 0; outChannelIndex < 2; outChannelIndex++) {
					const framesForChannel: Float32Array[] = []

					for (let frameIndex = 0; frameIndex < segmentSize; frameIndex++) {
						const frame = new Float32Array(fftSize)
						framesForChannel.push(frame)
					}

					outputSegmentFramesForChannel.push(framesForChannel)
				}

				let readPosition = 0

				for (let tensorChannelIndex = 0; tensorChannelIndex < 4; tensorChannelIndex++) {
					// Even tensor channels hold real parts, odd ones imaginary parts
					const componentOffset = tensorChannelIndex % 2
					const audioChannelIndex = tensorChannelIndex < 2 ? 0 : 1

					const framesForOutputChannel = outputSegmentFramesForChannel[audioChannelIndex]

					for (let binIndex = 0; binIndex < binCount; binIndex++) {
						const frameElementIndex = (binIndex << 1) + componentOffset

						for (let frameIndex = 0; frameIndex < segmentSize; frameIndex++) {
							framesForOutputChannel[frameIndex][frameElementIndex] = flattenedOutputTensor[readPosition++] * fftSizeReciprocal
						}
					}
				}
			}

			const outputAudioChannels: Float32Array[] = []

			if (enableTraceLogging) {
				await logger.startAsync(`Compute inverse STFT of model output for segment`)
			}

			for (let channelIndex = 0; channelIndex < 2; channelIndex++) {
				const samples = await stiftr(
					outputSegmentFramesForChannel[channelIndex],
					fftSize,
					fftWindowSize,
					fftHopSize,
					fftWindowType)

				outputAudioChannels.push(samples)
			}

			audioForSegments.push(outputAudioChannels)

			if (isLastSegment) {
				break
			}
		}

		// Join segments using overlapping Hann windows: each output sample becomes the
		// window-weighted average of every segment that covers it
		await logger.startAsync(`Join segments`)

		const joinedSegments = [new Float32Array(sampleCount), new Float32Array(sampleCount)]

		{
			const segmentCount = audioForSegments.length

			// Every segment is reconstructed from `segmentSize` frames, so presumably all
			// segments share this sample count (depends on stiftr's output length)
			const segmentSampleCount = audioForSegments[0][0].length

			const windowWeights = getWindowWeights('hann', segmentSampleCount)

			const sumOfWeightsForSample = new Float32Array(sampleCount)

			for (let segmentIndex = 0; segmentIndex < segmentCount; segmentIndex++) {
				const segmentStartFrameIndex = segmentIndex * segmentHopSize
				const segmentStartSampleIndex = segmentStartFrameIndex * fftHopSize

				const segmentSamples = audioForSegments[segmentIndex]

				for (let segmentSampleOffset = 0; segmentSampleOffset < segmentSampleCount; segmentSampleOffset++) {
					const sampleIndex = segmentStartSampleIndex + segmentSampleOffset

					// Samples beyond the end of the input audio are discarded
					if (sampleIndex >= sampleCount) {
						break
					}

					const weight = windowWeights[segmentSampleOffset]

					for (let channelIndex = 0; channelIndex < 2; channelIndex++) {
						joinedSegments[channelIndex][sampleIndex] += segmentSamples[channelIndex][segmentSampleOffset] * weight
					}

					sumOfWeightsForSample[sampleIndex] += weight
				}
			}

			// Normalize by the accumulated weight; the epsilon avoids dividing by zero
			// wherever the total weight is zero
			for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex++) {
				for (let channelIndex = 0; channelIndex < 2; channelIndex++) {
					joinedSegments[channelIndex][sampleIndex] /= sumOfWeightsForSample[sampleIndex] + 1e-8
				}
			}
		}

		const isolatedRawAudio: RawAudio = { audioChannels: joinedSegments, sampleRate }

		logger.end()

		return isolatedRawAudio
	}

	/**
	 * Creates the ONNX inference session on first use; subsequent calls return immediately.
	 */
	private async initializeSessionIfNeeded() {
		if (this.session) {
			return
		}

		const Onnx = await import('onnxruntime-node')

		// An explicitly requested provider takes precedence over the platform defaults
		const executionProviders: OnnxExecutionProvider[] = this.options.provider ? [this.options.provider] : getDefaultMDXNetProviders()

		this.onnxSessionOptions = getOnnxSessionOptions({ executionProviders })

		this.session = await Onnx.InferenceSession.create(this.modelFilePath, this.onnxSessionOptions)
	}
}

/**
 * Returns the default ONNX execution providers for MDX-NET.
 *
 * When DirectML is available, this is DirectML followed by CPU. Otherwise it is an empty list.
 * NOTE(review): an empty list presumably lets `getOnnxSessionOptions` apply its own default — confirm.
 */
export function getDefaultMDXNetProviders(): OnnxExecutionProvider[] {
	if (dmlProviderAvailable()) {
		return ['dml', 'cpu']
	} else {
		return []
	}
}

/**
 * Maps a known MDX-NET model name to its STFT and segmentation profile.
 *
 * @param modelName - One of the supported `MDXNetModelName` values.
 * @returns `mdxNetModelProfile1` or `mdxNetModelProfile2`.
 * @throws Error if the model name is not recognized.
 */
export function getProfileForMDXNetModelName(modelName: MDXNetModelName) {
	if (['UVR_MDXNET_1_9703', 'UVR_MDXNET_2_9682', 'UVR_MDXNET_3_9662', 'UVR_MDXNET_KARA'].includes(modelName)) {
		return mdxNetModelProfile1
	}

	if (['UVR_MDXNET_Main', 'Kim_Vocal_1', 'Kim_Vocal_2'].includes(modelName)) {
		return mdxNetModelProfile2
	}

	throw new Error(`Unsupported model name: '${modelName}'`)
}

/** Profile used by UVR_MDXNET_1_9703, UVR_MDXNET_2_9682, UVR_MDXNET_3_9662 and UVR_MDXNET_KARA. */
export const mdxNetModelProfile1: MDXNetModelProfile = {
	sampleRate: 44100,

	fftSize: 6144,
	fftWindowSize: 6144,
	fftHopSize: 1024,
	fftWindowType: 'hann',

	binCount: 2048,

	segmentSize: 256,
	segmentHopSize: 224,
}

/** Profile used by UVR_MDXNET_Main, Kim_Vocal_1 and Kim_Vocal_2. */
export const mdxNetModelProfile2: MDXNetModelProfile = {
	sampleRate: 44100,

	fftSize: 7680,
	fftWindowSize: 7680,
	fftHopSize: 1024,
	fftWindowType: 'hann',

	binCount: 3072,

	segmentSize: 256,
	segmentHopSize: 224,
}

/** STFT and segmentation parameters an MDX-NET model was built for. */
export interface MDXNetModelProfile {
	/** Required input sample rate, in Hz. */
	sampleRate: number

	/** FFT size passed to the forward and inverse STFT. */
	fftSize: number
	/** STFT analysis window length. */
	fftWindowSize: number
	/** Number of samples between consecutive STFT frames. */
	fftHopSize: number
	/** STFT window function. */
	fftWindowType: WindowType

	/** Number of frequency bins per frame fed to the model (third tensor dimension). */
	binCount: number

	/** Number of STFT frames per model invocation (fourth tensor dimension). */
	segmentSize: number
	/** Number of frames to advance between segments; segments overlap by `segmentSize - segmentHopSize` frames. */
	segmentHopSize: number
}

export type MDXNetModelName =
	'UVR_MDXNET_1_9703' |
	'UVR_MDXNET_2_9682' |
	'UVR_MDXNET_3_9662' |
	'UVR_MDXNET_KARA' |
	'UVR_MDXNET_Main' |
	'Kim_Vocal_1' |
	'Kim_Vocal_2'

export interface MDXNetOptions {
	/** Model name; within this module it is only used in log output. */
	model?: MDXNetModelName

	/** Execution provider override; when unset, `getDefaultMDXNetProviders()` is used. */
	provider?: OnnxExecutionProvider
}

export const defaultMDXNetOptions: MDXNetOptions = {
	model: 'UVR_MDXNET_1_9703',
	provider: undefined,
}