kitten-tts-webgpu
Version:
Run Kitten TTS (80M) locally in the browser via WebGPU. One function call: textToSpeech('Hello!') → WAV blob.
320 lines (319 loc) • 14.5 kB
TypeScript
/**
* WebGPU inference engine for Kitten TTS V0.8 (80M).
*
* Runs the full TTS pipeline on the GPU:
* phoneme_ids + style + speed → waveform (24kHz audio)
*/
import { type KittenConfig } from './types.js';
/**
 * WebGPU inference engine for Kitten TTS.
 *
 * Public surface:
 * - `init()`      — initialize the WebGPU device and compile shaders.
 * - `loadModel()` — load model weights (ONNX) and voice styles (NPZ).
 * - `generate()`  — run TTS inference.
 * - `destroy()`   — destroy all GPU resources.
 *
 * Model dimensions are detected from the loaded weight shapes (see the
 * dimension fields below and `detectDimensions`), so the same engine handles
 * several model sizes (mini, nano, ...) without hardcoded sizes.
 */
export declare class KittenTTSEngine {
    // ── Core GPU / model state ─────────────────────────────────────────────
    private device;
    /** Loaded model weights (see loadModel()). */
    private weights;
    /** Aliases from canonical (mini) onnx:: weight names to actual weight names in the loaded model. */
    private weightAliases;
    /** Compiled shader pipelines. */
    private pipelines;
    /** Voice style embeddings loaded from voices.npz. */
    private voices;
    // ── Model dimensions (detected from weight shapes) ─────────────────────
    /** LSTM hidden size for text encoder / predictor / duration / shared LSTMs (mini=256, nano=64). */
    private lstmHidden;
    /** Bidirectional LSTM output size = 2 * lstmHidden (mini=512, nano=128). */
    private lstmBidir;
    /** Text encoder embedding / CNN channel dim (mini=512, nano varies). Detected from text_encoder embedding weight. */
    private textEncChannels;
    /** Style embedding total dimension from voices.npz (typically 256 for all model sizes). */
    private styleDim;
    /** Style predictor half-dim = styleDim / 2 (typically 128). */
    private styleHalf;
    /** LSTM input size for predictor/duration/shared = lstmBidir + styleHalf. */
    private lstmInputSize;
    /** BERT embedding dim (mini=128, nano may differ). */
    private bertEmbedDim;
    /** BERT hidden size (mini=768, nano may differ). */
    private bertHiddenSize;
    /** BERT number of attention heads (mini=12). */
    private bertNumHeads;
    /** BERT head dim = bertHiddenSize / bertNumHeads. */
    private bertHeadDim;
    /** BERT FFN intermediate dim (mini=2048). */
    private bertFfnDim;
    /** BERT number of layer iterations (mini=12). */
    private bertNumLayers;
    /** Number of predictor LSTM+FC pairs (mini=3, nano=2). */
    private numPredLstmPairs;
    /** BERT encoder projection output dim (= lstmBidir, mini=512). */
    private bertProjDim;
    /** Number of text encoder CNN blocks (mini=3, nano=2). */
    private numTextEncCnnBlocks;
    /** Decoder encode output channels (mini=1024, nano=256). Detected from weight shapes. */
    private decEncodeOutCh;
    /** Decoder decode.0-2 output channels (mini=1024, nano=256). */
    private decDecodeOutCh;
    /** Decoder decode.3 output channels (mini=512, nano=256). */
    private decDecode3OutCh;
    /** HiFi-GAN ups.0 output channels (mini=256, nano=128). Detected from weight shapes. */
    private hifiUps0OutCh;
    /** HiFi-GAN ups.1 output channels (mini=128, nano=64). */
    private hifiUps1OutCh;
    /** N/F0 predictor block0 output channels (mini=512, nano=128). */
    private predBlock0OutCh;
    /** N/F0 predictor block1+ output channels (mini=256, nano=64). */
    private predBlock1OutCh;
    private config;
    // ── Batching, buffer lifetime and pooling ──────────────────────────────
    /** Uniform buffers created during dispatch, cleaned up after submit. */
    private pendingUniformBuffers;
    /** CPU-side weight cache used to re-upload GPU weight buffers after they are freed.
     *  Populated during loadModel() so GPU buffers can be freed and re-uploaded
     *  between generations to prevent iOS Safari jetsam kills.
     *  NOTE(review): earlier docs named a `freeGpuWeights()` method, which this
     *  class does not declare — confirm which code path frees the GPU weights. */
    private weightCache;
    /** Pending command buffers for batch submission. */
    private pendingCommandBuffers;
    /** Shared command encoder for batching dispatches (reduces iOS Safari crashes).
     *  Dispatches are recorded into this encoder and only submitted at readBuffer
     *  boundaries or when flushSharedEncoder() is called explicitly. */
    private sharedEncoder;
    /** Buffers to destroy after the shared encoder is submitted. */
    private deferredDestroys;
    /** Buffer pool: reuse GPU buffers by byte size instead of destroy+reallocate.
     *  Key insight: reusing buffers avoids Metal accumulating dead references to
     *  destroyed buffers, which is the root cause of iOS jetsam kills. */
    private bufferPool;
    /** Buffers to return to pool (not destroy) after the next shared encoder flush. */
    private deferredPoolReturns;
    /** Cached CPU copies of sin generator weights (avoid readBuffer every inference). */
    private sinGenWeights;
    // ── Debugging and profiling ────────────────────────────────────────────
    /** Debug mode: when true, intermediate activations are captured for comparison. */
    debugCapture: boolean;
    /** Captured activations (name → {data, shape}). Only populated when debugCapture=true. */
    debugActivations: Map<string, {
        data: Float32Array;
        shape: number[];
    }>;
    private debugBertBuffers;
    /** Performance profiling: when true, logs timing per pipeline stage. */
    profile: boolean;
    private timings;
    private _stageStart;
    /** Last timing report from generate(), available after each call. */
    lastTimings: {
        name: string;
        ms: number;
    }[];
    constructor(config?: KittenConfig);
    /** Start timing a pipeline stage. Call endStage() to record. */
    private startStage;
    /** End timing and record the stage duration (includes GPU sync).
     *  ALWAYS flushes batched dispatches + deferred destroys to keep peak GPU
     *  memory low (prevents iOS Safari jetsam kills). */
    private endStage;
    /** Print timing summary to console and store for external access (see lastTimings). */
    private printTimings;
    /** Capture a GPU buffer's contents as a named debug activation. No-op when debugCapture is off. */
    private captureDebug;
    // ── Lifecycle and model loading ────────────────────────────────────────
    /** Initialize WebGPU device and compile shaders. */
    init(): Promise<void>;
    /** Load model weights from ONNX file and voices from NPZ. */
    loadModel(onnxUrl: string, voicesUrl: string): Promise<void>;
    /**
     * Build aliases from canonical (mini) onnx:: weight names to actual names.
     *
     * Different model sizes (mini, nano) use different numeric IDs for onnx:: weights.
     * MatMul weights are always 10 in the same order, so simple positional mapping works.
     *
     * LSTM weights vary: mini has 6 LSTMs (18 weights), nano has 5 LSTMs (15 weights).
     * We map by semantic group rather than simple position:
     *   Group 0:     text encoder LSTM (3 weights)
     *   Group 1..N:  predictor LSTMs (3 weights each, N = numPredLstmPairs)
     *   Group N+1:   duration LSTM (3 weights)
     *   Group N+2:   shared LSTM (3 weights)
     */
    private buildOnnxAliases;
    /**
     * Detect model dimensions from loaded weight shapes.
     * This allows the engine to work with any model size (mini, nano, etc.)
     * without hardcoding dimensions.
     */
    private detectDimensions;
    /** Try to get a weight by name, checking aliases. Returns null if not found. */
    private tryGetWeight;
    // ── Inference entry point ──────────────────────────────────────────────
    /**
     * Run TTS inference.
     *
     * @param inputIds Token IDs to synthesize (described as phoneme IDs in the
     *   file header's pipeline summary).
     * @param voice Voice to use. NOTE(review): presumably a key into the voices
     *   loaded from voices.npz; the default is not visible in this declaration.
     * @param speed Speech-rate input to the pipeline. NOTE(review): the default
     *   is not visible in this declaration — confirm in the implementation.
     * @param textLength Raw text character count, used for voice style selection.
     * @param onProgress Optional callback that receives a pipeline stage name.
     *   The exact stage names are defined by the implementation.
     * @returns `waveform` — synthesized audio samples (24 kHz per the file
     *   header); `duration` — presumably the per-token frame durations used
     *   for length expansion (TODO confirm).
     */
    generate(inputIds: number[], voice?: string, speed?: number, textLength?: number, onProgress?: (stage: string) => void): Promise<{
        waveform: Float32Array;
        duration: Int32Array;
    }>;
    /** CPU LSTM implementation for debugging/verification. */
    private cpuLSTM;
    // ── Buffer helpers and dispatch batching ───────────────────────────────
    private createBuffer;
    private createEmptyBuffer;
    private createUniformBuffer;
    /** Destroy all pending uniform buffers (call after command submit). */
    private flushUniformBuffers;
    /** Get a weight tensor, throwing a descriptive error if missing.
     *  Handles cross-model weight name differences: micro/nano models may store
     *  some weights with/without `_quantized` suffix depending on their quantization.
     *  Also checks onnx:: weight aliases for cross-model-size compatibility. */
    private requireWeight;
    /** Start batching — flush any pending encoder + uniform buffers. */
    private beginBatch;
    /** End batch — flush encoder + uniform buffers. */
    private endBatch;
    /** Flush batch (alias). */
    private flushBatch;
    /** Submit batch — flush encoder + uniform buffers. */
    private submitBatch;
    /**
     * Execute a dispatch on a pipeline. Dispatches are batched into a shared
     * command encoder to reduce iOS Safari crashes (177 submits → ~5).
     * The shared encoder is flushed at readBuffer boundaries.
     */
    private dispatchSingle;
    /** Queue a buffer for destruction after the next shared encoder flush.
     *  Use this instead of buffer.destroy() when batching dispatches. */
    private deferDestroy;
    /** Flush the shared encoder — submit all batched dispatches and destroy deferred buffers. */
    private flushSharedEncoder;
    /** Flush uniform buffers and shared encoder. */
    private flushBatchEncoder;
    /** Get a buffer from the pool (or allocate new) — same size buffers are reused. */
    private poolGet;
    /** Queue a buffer to return to pool after the next shared encoder flush.
     *  Unlike deferDestroy, the buffer is kept alive for reuse. */
    private poolReturn;
    /** Destroy all pooled buffers (call when pool is no longer needed). */
    private destroyPool;
    /** Read buffer contents back to CPU. */
    private readBuffer;
    private dispatchEmbedding;
    private dispatchAdd;
    // ── Encoders ───────────────────────────────────────────────────────────
    /**
     * Run the BERT/ALBERT encoder.
     *
     * Dimensions come from the detected bert* fields; mini-model values are
     * shown in parentheses.
     *
     * Pipeline:
     *   1. Embedding sum (word + position) is already computed → [seqLen, bertEmbedDim (128)]
     *   2. LayerNorm on embeddings (bertEmbedDim)
     *   3. Linear projection bertEmbedDim → bertHiddenSize (128 → 768)
     *      (embedding_hidden_mapping_in)
     *   4. bertNumLayers (12) iterations of the shared ALBERT layer:
     *      a. Self-attention with bertNumHeads (12) heads of bertHeadDim
     *         (Q/K/V projections, scaled dot-product, output projection)
     *      b. Residual + LayerNorm (attention)
     *      c. FFN: Linear bertHiddenSize→bertFfnDim (768→2048, GELU)
     *         → Linear bertFfnDim→bertHiddenSize (2048→768)
     *      d. Residual + LayerNorm (full layer)
     *   5. Return [1, seqLen, bertHiddenSize (768)]
     */
    private runBertEncoder;
    /**
     * Run the text encoder.
     *
     * Pipeline:
     *   1. Embedding lookup: input_ids → [seqLen, textEncChannels]
     *   2. Transpose to channels-first: [textEncChannels, seqLen]
     *   3. N× Conv1d(C, C, k=5, pad=2) + LayerNorm + LeakyReLU (N=numTextEncCnnBlocks)
     *   4. Transpose to [seqLen, textEncChannels] for LSTM input
     *   5. Bidirectional LSTM (hidden=lstmHidden) → [seqLen, 2, lstmHidden]
     *
     * Returns buffer with shape [seqLen, 2, lstmHidden] (= [seqLen, lstmBidir] flattened)
     */
    private runTextEncoder;
    // ── Primitive dispatch helpers ─────────────────────────────────────────
    private dispatchLayerNorm;
    private dispatchMatmul;
    private dispatchGelu;
    private dispatchMatmulGelu;
    private dispatchMHA;
    private dispatchAddPass;
    private dispatchConv1d;
    private dispatchTranspose;
    private dispatchLeakyRelu;
    private dispatchSigmoid;
    private dispatchLSTM;
    // ── Predictor / decoder / vocoder blocks ───────────────────────────────
    /**
     * Run one AdaIN ResNet block.
     *
     * Pattern:
     *   norm1: InstanceNorm(x) → AdaIN(FC(style)) → LeakyReLU
     *   conv1: Conv1d(in→out, k=3, pad=1)
     *   norm2: InstanceNorm → AdaIN(FC(style)) → LeakyReLU
     *   conv2: Conv1d(out→out, k=3, pad=1)
     *   residual: x + conv2_out (with optional conv1x1 for channel change)
     *
     * Some blocks also have sigmoid → divide by 2 (block 0 in N/F0 predictor).
     */
    private runAdaINResNetBlock;
    /**
     * Run one decoder AdaIN conv block.
     *
     * Pattern:
     *   norm1: InstanceNorm(input) + AdaIN(FC(style))
     *   conv1: Conv1d(in→out, k=3, pad=1) + LeakyReLU
     *   norm2: InstanceNorm + AdaIN(FC(style))
     *   conv2: Conv1d(out→out, k=3, pad=1) + LeakyReLU
     *   residual: conv1x1(in→out) + output
     */
    private runDecoderBlock;
    private buildDecodeInput;
    /**
     * Run one HiFi-GAN residual block with AdaIN and Snake activation.
     *
     * ONNX pattern per iteration i:
     *   InstanceNorm(input) → AdaIN1(style) → Snake(alpha1) → conv1(dilated)
     *   → InstanceNorm → AdaIN2(style) → Snake(alpha2) → conv2
     *   → Add(conv2, input)  [simple residual, no alpha scaling]
     *
     * Snake activation: x + (1/alpha) * sin²(alpha * x)
     * alpha1/alpha2 are Snake parameters [1, C, 1], NOT residual weights.
     */
    private runHiFiGANResBlock;
    private dispatchSnake;
    private dispatchInstanceNorm;
    private dispatchAdaIN;
    private dispatchConvTranspose1d;
    /**
     * Depthwise ConvTranspose1d: each channel processed independently (groups=channels).
     * Used for pool layers in N.1, F0.1, decode.3 that double temporal resolution.
     * Weight shape: [channels, 1, kernel_size] stored as [channels * kernel_size].
     */
    private dispatchDepthwiseConvTranspose1d;
    /**
     * Resize 1D: nearest-neighbor interpolation (typically 2x upsampling).
     * Used on the residual path of blocks with pools.
     */
    private dispatchResize1d;
    /** Multiply all elements by a constant scale factor. */
    private dispatchScale;
    /** Concatenate two channel-first tensors along channel dimension. */
    private dispatchConcatChannels;
    /** AdaIN row-major: normed[rows, C] + style_fc[2*C] → output[rows, C]. */
    private dispatchAdaINRowMajor;
    /** Concat row-major A[rows, colsA] with broadcast B[colsB] → [rows, colsA+colsB]. */
    private dispatchConcatBroadcast;
    /** GPU length expansion: [seqLen, D] → [totalFrames, D] using duration cumsum. */
    private dispatchExpandRowMajor;
    /** GPU length expansion with transpose: [seqLen, D] row-major → [D, totalFrames] channel-first. */
    private dispatchExpandChannelFirst;
    /** GPU iSTFT synthesis: conv_post [22, genLen] → waveform [waveformLen]. */
    private dispatchISTFT;
    /** Reflection pad 1D along time dimension. */
    private dispatchReflectionPad1d;
    /** Alpha-weighted residual: output = current + alpha[ch] * residual */
    private dispatchAlphaResidual;
    // ── Shader compilation ─────────────────────────────────────────────────
    private compileShaders;
    /** Infer bind group layout from WGSL source by parsing @binding annotations. */
    private inferBindGroupLayout;
    /**
     * Generate source excitation signal for noise injection.
     *
     * Pipeline (computed on CPU — cumulative sum is sequential):
     *   1. Read F0_proj from GPU [1, f0Length]
     *   2. Upsample F0 to waveform rate via nearest-neighbor [waveLen]
     *   3. For 9 harmonics: cumsum(k * f0 / sr) → sin(2π * phase)
     *      Voiced (f0 > 0): scaled sin wave. Unvoiced: Gaussian noise.
     *   4. Linear(9→1) + bias + tanh → [waveLen, 1]
     *   5. Edge-pad 10 → [waveLen + 20]
     *   6. Forward STFT via Conv1d(1→11, k=20, s=5) for real + imag → magnitude + phase
     *   7. Concat [magnitude, phase] → [22, stftLen]
     *   8. Upload to GPU
     */
    private generateSourceExcitation;
    // ── Resource state and teardown ────────────────────────────────────────
    /** Check if model weights are loaded (GPU buffers alive). */
    get weightsLoaded(): boolean;
    /** Check if CPU weight cache is available for re-upload. */
    get hasCachedWeights(): boolean;
    /** Destroy all GPU resources. */
    destroy(): void;
}