UNPKG

kitten-tts-webgpu

Version:

Run Kitten TTS (80M) locally in the browser via WebGPU. One function call: textToSpeech('Hello!') → WAV blob.

320 lines (319 loc) 14.5 kB
/**
 * WebGPU inference engine for Kitten TTS V0.8 (80M).
 *
 * Runs the full TTS pipeline on the GPU:
 *   phoneme_ids + style + speed → waveform (24kHz audio)
 */
import type { KittenConfig } from './types.js';
/**
 * GPU inference engine for Kitten TTS.
 *
 * Public surface: `init()`, `loadModel()`, `generate()`, `destroy()`, the
 * `weightsLoaded` / `hasCachedWeights` getters, and the `debugCapture` /
 * `profile` diagnostic switches (with `debugActivations` / `lastTimings` outputs).
 * NOTE(review): presumably `init()` then `loadModel()` must complete before
 * `generate()` is called — confirm against the implementation.
 */
export declare class KittenTTSEngine {
    private device;
    private weights;
    /** Aliases from canonical (mini) onnx:: weight names to actual weight names in the loaded model. */
    private weightAliases;
    private pipelines;
    private voices;
    /** LSTM hidden size for text encoder / predictor / duration / shared LSTMs (mini=256, nano=64). */
    private lstmHidden;
    /** Bidirectional LSTM output size = 2 * lstmHidden (mini=512, nano=128). */
    private lstmBidir;
    /** Text encoder embedding / CNN channel dim (mini=512, nano varies). Detected from text_encoder embedding weight. */
    private textEncChannels;
    /** Style embedding total dimension from voices.npz (typically 256 for all model sizes). */
    private styleDim;
    /** Style predictor half-dim = styleDim / 2 (typically 128). */
    private styleHalf;
    /** LSTM input size for predictor/duration/shared = lstmBidir + styleHalf. */
    private lstmInputSize;
    /** BERT embedding dim (mini=128, nano may differ). */
    private bertEmbedDim;
    /** BERT hidden size (mini=768, nano may differ). */
    private bertHiddenSize;
    /** BERT number of attention heads (mini=12). */
    private bertNumHeads;
    /** BERT head dim = bertHiddenSize / bertNumHeads. */
    private bertHeadDim;
    /** BERT FFN intermediate dim (mini=2048). */
    private bertFfnDim;
    /** BERT number of layer iterations (mini=12). */
    private bertNumLayers;
    /** Number of predictor LSTM+FC pairs (mini=3, nano=2). */
    private numPredLstmPairs;
    /** BERT encoder projection output dim (= lstmBidir, mini=512). */
    private bertProjDim;
    /** Number of text encoder CNN blocks (mini=3, nano=2). */
    private numTextEncCnnBlocks;
    /** Decoder encode output channels (mini=1024, nano=256). Detected from weight shapes. */
    private decEncodeOutCh;
    /** Decoder decode.0-2 output channels (mini=1024, nano=256). */
    private decDecodeOutCh;
    /** Decoder decode.3 output channels (mini=512, nano=256). */
    private decDecode3OutCh;
    /** HiFi-GAN ups.0 output channels (mini=256, nano=128). Detected from weight shapes. */
    private hifiUps0OutCh;
    /** HiFi-GAN ups.1 output channels (mini=128, nano=64). */
    private hifiUps1OutCh;
    /** N/F0 predictor block0 output channels (mini=512, nano=128). */
    private predBlock0OutCh;
    /** N/F0 predictor block1+ output channels (mini=256, nano=64). */
    private predBlock1OutCh;
    private config;
    /** Uniform buffers created during dispatch, cleaned up after submit. */
    private pendingUniformBuffers;
    /**
     * CPU-side weight cache for re-uploading after freeGpuWeights().
     * Populated during loadModel() so we can free/re-upload GPU buffers
     * between generations to prevent iOS Safari jetsam kills.
     * NOTE(review): freeGpuWeights() is not declared on this class — confirm
     * whether it was renamed/removed or lives elsewhere.
     */
    private weightCache;
    /** Pending command buffers for batch submission. */
    private pendingCommandBuffers;
    /**
     * Shared command encoder for batching dispatches (reduces iOS Safari crashes).
     * Dispatches are recorded into this encoder and only submitted at readBuffer
     * boundaries or when flushSharedEncoder() is called explicitly.
     */
    private sharedEncoder;
    /** Buffers to destroy after the shared encoder is submitted. */
    private deferredDestroys;
    /**
     * Buffer pool: reuse GPU buffers by byte size instead of destroy+reallocate.
     * Key insight: reusing buffers avoids Metal accumulating dead references to
     * destroyed buffers, which is the root cause of iOS jetsam kills.
     */
    private bufferPool;
    /** Buffers to return to pool (not destroy) after the next shared encoder flush. */
    private deferredPoolReturns;
    /** Cached CPU copies of sin generator weights (avoid readBuffer every inference). */
    private sinGenWeights;
    /** Debug mode: when true, intermediate activations are captured for comparison. */
    debugCapture: boolean;
    /** Captured activations (name → {data, shape}). Only populated when debugCapture=true. */
    debugActivations: Map<string, {
        data: Float32Array;
        shape: number[];
    }>;
    private debugBertBuffers;
    /** Performance profiling: when true, logs timing per pipeline stage. */
    profile: boolean;
    private timings;
    private _stageStart;
    constructor(config?: KittenConfig);
    /** Start timing a pipeline stage. Call endStage() to record. */
    private startStage;
    /**
     * End timing and record the stage duration (includes GPU sync).
     * ALWAYS flushes batched dispatches + deferred destroys to keep peak GPU
     * memory low (prevents iOS Safari jetsam kills).
     */
    private endStage;
    /** Last timing report from generate(), available after each call. */
    lastTimings: {
        name: string;
        ms: number;
    }[];
    /** Print timing summary to console and store for external access. */
    private printTimings;
    /** Capture a GPU buffer's contents as a named debug activation. No-op when debugCapture is off. */
    private captureDebug;
    /** Initialize WebGPU device and compile shaders. */
    init(): Promise<void>;
    /** Load model weights from ONNX file and voices from NPZ. */
    loadModel(onnxUrl: string, voicesUrl: string): Promise<void>;
    /**
     * Build aliases from canonical (mini) onnx:: weight names to actual names.
     *
     * Different model sizes (mini, nano) use different numeric IDs for onnx:: weights.
     * MatMul weights are always 10 in the same order, so simple positional mapping works.
     *
     * LSTM weights vary: mini has 6 LSTMs (18 weights), nano has 5 LSTMs (15 weights).
     * We map by semantic group rather than simple position:
     *   Group 0: text encoder LSTM (3 weights)
     *   Group 1..N: predictor LSTMs (3 weights each, N varies by model)
     *   Group N+1: duration LSTM (3 weights)
     *   Group N+2: shared LSTM (3 weights)
     */
    private buildOnnxAliases;
    /**
     * Detect model dimensions from loaded weight shapes.
     * This allows the engine to work with any model size (mini, nano, etc.)
     * without hardcoding dimensions.
     */
    private detectDimensions;
    /** Try to get a weight by name, checking aliases. Returns null if not found. */
    private tryGetWeight;
    /**
     * Run TTS inference.
     *
     * @param inputIds Input token IDs (the module header describes the pipeline input as phoneme_ids).
     * @param voice Voice key; presumably selects a style embedding from the voices NPZ loaded by
     *   loadModel(). The default when omitted is chosen by the implementation.
     * @param speed Speech speed factor; the default when omitted is chosen by the implementation.
     * @param textLength Raw text character count for voice style selection.
     * @param onProgress Optional callback invoked with a stage label as inference progresses.
     *   NOTE(review): the exact stage names are defined by the implementation.
     * @returns `waveform`: synthesized audio samples (24kHz per the module header).
     *   `duration`: presumably per-input-token frame counts used for length expansion — confirm.
     */
    generate(inputIds: number[], voice?: string, speed?: number, textLength?: number, onProgress?: (stage: string) => void): Promise<{
        waveform: Float32Array;
        duration: Int32Array;
    }>;
    /** CPU LSTM implementation for debugging/verification. */
    private cpuLSTM;
    private createBuffer;
    private createEmptyBuffer;
    private createUniformBuffer;
    /** Destroy all pending uniform buffers (call after command submit). */
    private flushUniformBuffers;
    /**
     * Get a weight tensor, throwing a descriptive error if missing.
     * Handles cross-model weight name differences: micro/nano models may store
     * some weights with/without `_quantized` suffix depending on their quantization.
     * Also checks onnx:: weight aliases for cross-model-size compatibility.
     */
    private requireWeight;
    /** Start batching — flush any pending encoder + uniform buffers. */
    private beginBatch;
    /** End batch — flush encoder + uniform buffers. */
    private endBatch;
    /** Flush batch (alias). */
    private flushBatch;
    /** Submit batch — flush encoder + uniform buffers. */
    private submitBatch;
    /**
     * Execute a dispatch on a pipeline. Dispatches are batched into a shared
     * command encoder to reduce iOS Safari crashes (177 submits → ~5).
     * The shared encoder is flushed at readBuffer boundaries.
     */
    private dispatchSingle;
    /**
     * Queue a buffer for destruction after the next shared encoder flush.
     * Use this instead of buffer.destroy() when batching dispatches.
     */
    private deferDestroy;
    /** Flush the shared encoder — submit all batched dispatches and destroy deferred buffers. */
    private flushSharedEncoder;
    /** Flush uniform buffers and shared encoder. */
    private flushBatchEncoder;
    /** Get a buffer from the pool (or allocate new) — same size buffers are reused. */
    private poolGet;
    /**
     * Queue a buffer to return to pool after the next shared encoder flush.
     * Unlike deferDestroy, the buffer is kept alive for reuse.
     */
    private poolReturn;
    /** Destroy all pooled buffers (call when pool is no longer needed). */
    private destroyPool;
    /** Read buffer contents back to CPU. */
    private readBuffer;
    private dispatchEmbedding;
    private dispatchAdd;
    /**
     * Run the BERT/ALBERT encoder.
     *
     * Pipeline:
     *   1. Embedding sum (word + position) is already computed → [seqLen, 128]
     *   2. LayerNorm on embeddings (128-dim)
     *   3. Linear projection 128 → 768 (embedding_hidden_mapping_in)
     *   4. 12 iterations of shared ALBERT layer:
     *      a. Self-attention (Q/K/V projections, scaled dot-product, output projection)
     *      b. Residual + LayerNorm (attention)
     *      c. FFN: Linear 768→2048 (GELU) → Linear 2048→768
     *      d. Residual + LayerNorm (full layer)
     *   5. Return [1, seqLen, 768]
     *
     * NOTE(review): the numbers above are the mini configuration. The field docs for
     * bertEmbedDim / bertHiddenSize / bertFfnDim / bertNumLayers indicate these are
     * detected per model and may differ for other sizes.
     */
    private runBertEncoder;
    /**
     * Run the text encoder.
     *
     * Pipeline:
     *   1. Embedding lookup: input_ids → [seqLen, textEncChannels]
     *   2. Transpose to channels-first: [textEncChannels, seqLen]
     *   3. N× Conv1d(C, C, k=5, pad=2) + LayerNorm + LeakyReLU (N=numTextEncCnnBlocks)
     *   4. Transpose to [seqLen, textEncChannels] for LSTM input
     *   5. Bidirectional LSTM (hidden=lstmHidden) → [seqLen, 2, lstmHidden]
     *
     * Returns buffer with shape [seqLen, 2, lstmHidden] (= [seqLen, lstmBidir] flattened)
     */
    private runTextEncoder;
    private dispatchLayerNorm;
    private dispatchMatmul;
    private dispatchGelu;
    private dispatchMatmulGelu;
    private dispatchMHA;
    private dispatchAddPass;
    private dispatchConv1d;
    private dispatchTranspose;
    private dispatchLeakyRelu;
    private dispatchSigmoid;
    private dispatchLSTM;
    /**
     * Run one AdaIN ResNet block.
     *
     * Pattern:
     *   norm1: InstanceNorm(x) → AdaIN(FC(style)) → LeakyReLU
     *   conv1: Conv1d(in→out, k=3, pad=1)
     *   norm2: InstanceNorm → AdaIN(FC(style)) → LeakyReLU
     *   conv2: Conv1d(out→out, k=3, pad=1)
     *   residual: x + conv2_out (with optional conv1x1 for channel change)
     *
     * Some blocks also have sigmoid → divide by 2 (block 0 in N/F0 predictor).
     */
    private runAdaINResNetBlock;
    /**
     * Run one decoder AdaIN conv block.
     *
     * Pattern:
     *   norm1: InstanceNorm(input) + AdaIN(FC(style))
     *   conv1: Conv1d(in→out, k=3, pad=1) + LeakyReLU
     *   norm2: InstanceNorm + AdaIN(FC(style))
     *   conv2: Conv1d(out→out, k=3, pad=1) + LeakyReLU
     *   residual: conv1x1(in→out) + output
     */
    private runDecoderBlock;
    private buildDecodeInput;
    /**
     * Run one HiFi-GAN residual block with AdaIN and Snake activation.
     *
     * ONNX pattern per iteration i:
     *   InstanceNorm(input) → AdaIN1(style) → Snake(alpha1) → conv1(dilated)
     *   → InstanceNorm → AdaIN2(style) → Snake(alpha2) → conv2
     *   → Add(conv2, input) [simple residual, no alpha scaling]
     *
     * Snake activation: x + (1/alpha) * sin²(alpha * x)
     * alpha1/alpha2 are Snake parameters [1, C, 1], NOT residual weights.
     */
    private runHiFiGANResBlock;
    private dispatchSnake;
    private dispatchInstanceNorm;
    private dispatchAdaIN;
    private dispatchConvTranspose1d;
    /**
     * Depthwise ConvTranspose1d: each channel processed independently (groups=channels).
     * Used for pool layers in N.1, F0.1, decode.3 that double temporal resolution.
     * Weight shape: [channels, 1, kernel_size] stored as [channels * kernel_size].
     */
    private dispatchDepthwiseConvTranspose1d;
    /**
     * Resize 1D: nearest-neighbor interpolation (typically 2x upsampling).
     * Used on the residual path of blocks with pools.
     */
    private dispatchResize1d;
    /** Multiply all elements by a constant scale factor. */
    private dispatchScale;
    /** Concatenate two channel-first tensors along channel dimension. */
    private dispatchConcatChannels;
    /** AdaIN row-major: normed[rows, C] + style_fc[2*C] → output[rows, C]. */
    private dispatchAdaINRowMajor;
    /** Concat row-major A[rows, colsA] with broadcast B[colsB] → [rows, colsA+colsB]. */
    private dispatchConcatBroadcast;
    /** GPU length expansion: [seqLen, D] → [totalFrames, D] using duration cumsum. */
    private dispatchExpandRowMajor;
    /** GPU length expansion with transpose: [seqLen, D] row-major → [D, totalFrames] channel-first. */
    private dispatchExpandChannelFirst;
    /** GPU iSTFT synthesis: conv_post [22, genLen] → waveform [waveformLen]. */
    private dispatchISTFT;
    /** Reflection pad 1D along time dimension. */
    private dispatchReflectionPad1d;
    /** Alpha-weighted residual: output = current + alpha[ch] * residual */
    private dispatchAlphaResidual;
    private compileShaders;
    /** Infer bind group layout from WGSL source by parsing @binding annotations. */
    private inferBindGroupLayout;
    /**
     * Generate source excitation signal for noise injection.
     *
     * Pipeline (computed on CPU — cumulative sum is sequential):
     *   1. Read F0_proj from GPU [1, f0Length]
     *   2. Upsample F0 to waveform rate via nearest-neighbor [waveLen]
     *   3. For 9 harmonics: cumsum(k * f0 / sr) → sin(2π * phase)
     *      Voiced (f0 > 0): scaled sin wave. Unvoiced: Gaussian noise.
     *   4. Linear(9→1) + bias + tanh → [waveLen, 1]
     *   5. Edge-pad 10 → [waveLen + 20]
     *   6. Forward STFT via Conv1d(1→11, k=20, s=5) for real + imag → magnitude + phase
     *   7. Concat [magnitude, phase] → [22, stftLen]
     *   8. Upload to GPU
     */
    private generateSourceExcitation;
    /** Check if model weights are loaded (GPU buffers alive). */
    get weightsLoaded(): boolean;
    /** Check if CPU weight cache is available for re-upload. */
    get hasCachedWeights(): boolean;
    /** Destroy all GPU resources. */
    destroy(): void;
}