echogarden
An easy-to-use speech toolset. Includes tools for synthesis, recognition, alignment, speech translation, language detection, source separation and more.
import type * as Onnx from 'onnxruntime-node'
import { getEmptyRawAudio, RawAudio } from '../audio/AudioUtilities.js'
import { getWindowWeights, createStftrGenerator, stiftr, WindowType } from '../dsp/FFT.js'
import { logToStderr } from '../utilities/Utilities.js'
import { Logger } from '../utilities/Logger.js'
import { OnnxExecutionProvider, dmlProviderAvailable, getOnnxSessionOptions } from '../utilities/OnnxUtilities.js'
import chalk from 'chalk'
import { WindowedList } from '../data-structures/WindowedList.js'
import { logLevelGreaterOrEqualTo } from '../api/API.js'
const log = logToStderr
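// Convenience wrapper: constructs an MDXNet instance for the given model file and profile,
// then runs source separation on the given stereo audio.
//
// Usage sketch (the model file path and `stereoRawAudio` variable below are hypothetical):
//   const profile = getProfileForMDXNetModelName('UVR_MDXNET_1_9703')
//   const isolated = await isolate(stereoRawAudio, '/path/to/UVR_MDXNET_1_9703.onnx', profile, { model: 'UVR_MDXNET_1_9703' })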
export async function isolate(
rawAudio: RawAudio,
modelFilePath: string,
modelProfile: MDXNetModelProfile,
options: MDXNetOptions) {
const model = new MDXNet(modelFilePath, modelProfile, options)
return model.processAudio(rawAudio)
}
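// Runs an MDX-NET ONNX model over the STFT of a stereo recording, segment by segment,
// and reconstructs the isolated stem via inverse STFT and windowed overlap-add.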
export class MDXNet {
session?: Onnx.InferenceSession
onnxSessionOptions?: Onnx.InferenceSession.SessionOptions
constructor(
public readonly modelFilePath: string,
public readonly modelProfile: MDXNetModelProfile,
public readonly options: MDXNetOptions) {
}
async processAudio(rawAudio: RawAudio) {
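// The model expects stereo input at the sample rate given by the model profile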
if (rawAudio.audioChannels.length !== 2) {
throw new Error(`Input audio must be stereo`)
}
if (rawAudio.sampleRate !== this.modelProfile.sampleRate) {
throw new Error(`Input audio must have a sample rate of ${this.modelProfile.sampleRate} Hz`)
}
if (rawAudio.audioChannels[0].length === 0) {
return getEmptyRawAudio(rawAudio.audioChannels.length, rawAudio.sampleRate)
}
const enableTraceLogging = logLevelGreaterOrEqualTo('trace')
const logger = new Logger()
await logger.startAsync(`Initialize session for MDX-NET model '${this.options.model!}'`)
await this.initializeSessionIfNeeded()
logger.end()
logger.logTitledMessage(`Using ONNX execution provider`, `${this.onnxSessionOptions!.executionProviders!.join(', ')}`)
const Onnx = await import('onnxruntime-node')
const sampleRate = this.modelProfile.sampleRate
const fftSize = this.modelProfile.fftSize
const fftWindowSize = this.modelProfile.fftWindowSize
const fftHopSize = this.modelProfile.fftHopSize
const fftWindowType = this.modelProfile.fftWindowType
const binCount = this.modelProfile.binCount
const segmentSize = this.modelProfile.segmentSize
const segmentHopSize = this.modelProfile.segmentHopSize
const sampleCount = rawAudio.audioChannels[0].length
const fftSizeReciprocal = 1 / fftSize
// Initialize generators for STFT frames for each channel
const fftFramesLeftGenerator = await createStftrGenerator(rawAudio.audioChannels[0], fftSize, fftWindowSize, fftHopSize, fftWindowType)
const fftFramesRightGenerator = await createStftrGenerator(rawAudio.audioChannels[1], fftSize, fftWindowSize, fftHopSize, fftWindowType)
// Initialize windowed lists to store recently computed STFT frames
const fftFramesLeftWindowedList = new WindowedList<Float32Array>(segmentSize)
const fftFramesRightWindowedList = new WindowedList<Float32Array>(segmentSize)
const audioForSegments: Float32Array[][] = []
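// Process the spectrogram in overlapping segments of `segmentSize` STFT frames,
// advancing by `segmentHopSize` frames per iteration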
for (let segmentStartFrameOffset = 0; ; segmentStartFrameOffset += segmentHopSize) {
const segmentEndFrameOffset = segmentStartFrameOffset + segmentSize
const timePosition = segmentStartFrameOffset * (fftHopSize / sampleRate)
if (enableTraceLogging) {
await logger.startAsync(`Compute STFT of segment at time position ${timePosition.toFixed(2)}`, undefined, chalk.magentaBright)
} else {
await logger.startAsync(`Process segment at time position ${timePosition.toFixed(2)}`)
}
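// Pull STFT frames from the generators until the windowed lists cover the current segment,
// or until the input is exhausted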
while (fftFramesLeftWindowedList.endOffset < segmentEndFrameOffset) {
const nextLeftFrameResult = fftFramesLeftGenerator.next()
if (nextLeftFrameResult.done) {
break
}
const nextRightFrameResult = fftFramesRightGenerator.next()
if (nextRightFrameResult.done) {
break
}
fftFramesLeftWindowedList.add(nextLeftFrameResult.value)
fftFramesRightWindowedList.add(nextRightFrameResult.value)
}
const fftFramesForSegment = [
fftFramesLeftWindowedList.slice(segmentStartFrameOffset, segmentEndFrameOffset),
fftFramesRightWindowedList.slice(segmentStartFrameOffset, segmentEndFrameOffset)
]
const segmentLength = fftFramesForSegment[0].length
const isLastSegment = segmentLength < segmentSize
if (enableTraceLogging) {
await logger.startAsync(`Reshape STFT frames`)
}
const flattenedInputTensor = new Float32Array(1 * 4 * binCount * segmentSize)
{
let writePosition = 0
// The 4 slices along the tensor's channel dimension are ordered as:
// <Channel 0 real> <Channel 0 imaginary> <Channel 1 real> <Channel 1 imaginary>
for (let tensorElementIndex = 0; tensorElementIndex < 4; tensorElementIndex++) {
const isRealComponentTensorElementIndex = tensorElementIndex % 2 === 0
const audioChannelIndex = tensorElementIndex < 2 ? 0 : 1
for (let binIndex = 0; binIndex < binCount; binIndex++) {
for (let frameIndex = 0; frameIndex < segmentSize; frameIndex++) {
let value = 0
if (frameIndex < segmentLength) {
const frame = fftFramesForSegment[audioChannelIndex][frameIndex]
if (isRealComponentTensorElementIndex) {
value = frame[binIndex << 1]
} else {
value = frame[(binIndex << 1) + 1]
}
}
flattenedInputTensor[writePosition++] = value
}
}
}
}
if (enableTraceLogging) {
await logger.startAsync(`Process segment with MDXNet model`)
}
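// Run the model on an input tensor of shape [1, 4, binCount, segmentSize];
// its output is read back assuming the same layout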
const inputTensor = new Onnx.Tensor('float32', flattenedInputTensor, [1, 4, binCount, segmentSize])
const { output: outputTensor } = await this.session!.run({ input: inputTensor })
if (enableTraceLogging) {
await logger.startAsync('Reshape processed frames')
}
const flattenedOutputTensor = outputTensor.data as Float32Array
const outputSegmentFramesForChannel: Float32Array[][] = []
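// Unflatten the model output back into per-channel, interleaved-complex STFT frames,
// scaling each value by 1 / fftSize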
{
for (let outChannelIndex = 0; outChannelIndex < 2; outChannelIndex++) {
const framesForChannel: Float32Array[] = []
for (let frameIndex = 0; frameIndex < segmentSize; frameIndex++) {
const frame = new Float32Array(fftSize)
framesForChannel.push(frame)
}
outputSegmentFramesForChannel.push(framesForChannel)
}
let readPosition = 0
for (let tensorChannelIndex = 0; tensorChannelIndex < 4; tensorChannelIndex++) {
const isRealTensorChannelIndex = tensorChannelIndex % 2 === 0
const audioChannelIndex = tensorChannelIndex < 2 ? 0 : 1
const framesForOutputChannel = outputSegmentFramesForChannel[audioChannelIndex]
for (let binIndex = 0; binIndex < binCount; binIndex++) {
for (let frameIndex = 0; frameIndex < segmentSize; frameIndex++) {
const outFrame = framesForOutputChannel[frameIndex]
if (isRealTensorChannelIndex) {
outFrame[binIndex << 1] = flattenedOutputTensor[readPosition++] * fftSizeReciprocal
} else {
outFrame[(binIndex << 1) + 1] = flattenedOutputTensor[readPosition++] * fftSizeReciprocal
}
}
}
}
}
const outputAudioChannels: Float32Array[] = []
if (enableTraceLogging) {
await logger.startAsync(`Compute inverse STFT of model output for segment`)
}
for (let channelIndex = 0; channelIndex < 2; channelIndex++) {
const samples = await stiftr(
outputSegmentFramesForChannel[channelIndex],
fftSize,
fftWindowSize,
fftHopSize,
fftWindowType)
outputAudioChannels.push(samples)
}
audioForSegments.push(outputAudioChannels)
if (isLastSegment) {
break
}
}
// Join segments using overlapping Hann windows
await logger.startAsync(`Join segments`)
const joinedSegments = [new Float32Array(sampleCount), new Float32Array(sampleCount)]
{
const segmentCount = audioForSegments.length
const segmentSampleCount = audioForSegments[0][0].length
const windowWeights = getWindowWeights('hann', segmentSampleCount)
const sumOfWeightsForSample = new Float32Array(sampleCount)
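// Overlap-add: accumulate each Hann-windowed segment into the output,
// tracking the total window weight applied to each sample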
for (let segmentIndex = 0; segmentIndex < segmentCount; segmentIndex++) {
const segmentStartFrameIndex = segmentIndex * segmentHopSize
const segmentStartSampleIndex = segmentStartFrameIndex * fftHopSize
const segmentSamples = audioForSegments[segmentIndex]
for (let segmentSampleOffset = 0; segmentSampleOffset < segmentSampleCount; segmentSampleOffset++) {
const sampleIndex = segmentStartSampleIndex + segmentSampleOffset
if (sampleIndex >= sampleCount) {
break
}
const weight = windowWeights[segmentSampleOffset]
for (let channelIndex = 0; channelIndex < 2; channelIndex++) {
joinedSegments[channelIndex][sampleIndex] += segmentSamples[channelIndex][segmentSampleOffset] * weight
}
sumOfWeightsForSample[sampleIndex] += weight
}
}
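// Normalize each sample by its accumulated window weight (the small epsilon avoids division by zero)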
for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex++) {
for (let channelIndex = 0; channelIndex < 2; channelIndex++) {
joinedSegments[channelIndex][sampleIndex] /= sumOfWeightsForSample[sampleIndex] + 1e-8
}
}
}
const isolatedRawAudio: RawAudio = { audioChannels: joinedSegments, sampleRate }
logger.end()
return isolatedRawAudio
}
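// Lazily create the ONNX inference session on first use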
private async initializeSessionIfNeeded() {
if (this.session) {
return
}
const Onnx = await import('onnxruntime-node')
const executionProviders: OnnxExecutionProvider[] =
this.options.provider ? [this.options.provider] : getDefaultMDXNetProviders()
this.onnxSessionOptions = getOnnxSessionOptions({ executionProviders })
this.session = await Onnx.InferenceSession.create(this.modelFilePath, this.onnxSessionOptions)
}
}
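// Prefer the DirectML provider (with CPU fallback) when available;
// otherwise return an empty list, expressing no explicit provider preference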
export function getDefaultMDXNetProviders(): OnnxExecutionProvider[] {
if (dmlProviderAvailable()) {
return ['dml', 'cpu']
} else {
return []
}
}
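// Map a supported model name to its STFT and segmentation profile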
export function getProfileForMDXNetModelName(modelName: MDXNetModelName) {
if (['UVR_MDXNET_1_9703', 'UVR_MDXNET_2_9682', 'UVR_MDXNET_3_9662', 'UVR_MDXNET_KARA'].includes(modelName)) {
return mdxNetModelProfile1
}
if (['UVR_MDXNET_Main', 'Kim_Vocal_1', 'Kim_Vocal_2'].includes(modelName)) {
return mdxNetModelProfile2
}
throw new Error(`Unsupported model name: '${modelName}'`)
}
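// Profile used by the original UVR MDX-NET models: 6144-point FFT, 2048 frequency bins,
// 256-frame segments with a 224-frame hop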
export const mdxNetModelProfile1: MDXNetModelProfile = {
sampleRate: 44100,
fftSize: 6144,
fftWindowSize: 6144,
fftHopSize: 1024,
fftWindowType: 'hann',
binCount: 2048,
segmentSize: 256,
segmentHopSize: 224,
}
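// Profile used by the larger models (UVR_MDXNET_Main, Kim_Vocal_1/2): 7680-point FFT, 3072 frequency bins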
export const mdxNetModelProfile2: MDXNetModelProfile = {
sampleRate: 44100,
fftSize: 7680,
fftWindowSize: 7680,
fftHopSize: 1024,
fftWindowType: 'hann',
binCount: 3072,
segmentSize: 256,
segmentHopSize: 224,
}
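// STFT and segmentation parameters expected by a particular MDX-NET model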
export interface MDXNetModelProfile {
sampleRate: number
fftSize: number
fftWindowSize: number
fftHopSize: number
fftWindowType: WindowType
binCount: number
segmentSize: number
segmentHopSize: number
}
export type MDXNetModelName =
'UVR_MDXNET_1_9703' |
'UVR_MDXNET_2_9682' |
'UVR_MDXNET_3_9662' |
'UVR_MDXNET_KARA' |
'UVR_MDXNET_Main' |
'Kim_Vocal_1' |
'Kim_Vocal_2'
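// `model` selects which pretrained MDX-NET model to use;
// `provider` optionally overrides the ONNX execution provider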
export interface MDXNetOptions {
model?: MDXNetModelName
provider?: OnnxExecutionProvider
}
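// By default, use the original UVR MDX-NET model and choose the execution provider automatically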
export const defaultMDXNetOptions: MDXNetOptions = {
model: 'UVR_MDXNET_1_9703',
provider: undefined,
}