UNPKG

kitten-tts-webgpu

Version:

Run Kitten TTS (80M) locally in the browser via WebGPU. One function call: textToSpeech('Hello!') → WAV blob.

1,540 lines (1,385 loc) 3.02 MB
function l4() { typeof ReadableStream < "u" && !(Symbol.asyncIterator in ReadableStream.prototype) && (ReadableStream.prototype[Symbol.asyncIterator] = async function* () { const k = this.getReader(); try { for (; ; ) { const { value: u, done: B } = await k.read(); if (B) break; yield u; } } finally { k.releaseLock(); } }); } const ja = { symbols: Array.from("$;:,.!?¡¿—…“«»”„ ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"), voiceAliases: { Bella: "expr-voice-2-f", Jasper: "expr-voice-2-m", Luna: "expr-voice-3-f", Bruno: "expr-voice-3-m", Rosie: "expr-voice-4-f", Hugo: "expr-voice-4-m", Kiki: "expr-voice-5-f", Leo: "expr-voice-5-m" }, sampleRate: 24e3 }; class d4 { buffer; view; constructor(u) { this.buffer = new Uint8Array(u), this.view = new DataView(u); } /** Parse all initializer tensors from the ONNX model. */ parseInitializers() { const u = /* @__PURE__ */ new Map(), B = this.findField(this.buffer, 0, this.buffer.length, 7); if (!B) throw new Error("Could not find graph in ONNX model"); let l = B.start; for (; l < B.end; ) { const E = this.readTag(l); if (!E) break; if (E.fieldNumber === 5 && E.wireType === 2) { const D = this.readVarint(E.dataStart), m = D.end, H = m + D.value, Z = this.parseTensorProto(m, H); Z && Z.name && u.set(Z.name, Z), l = H; } else l = this.skipField(E); } return u; } parseTensorProto(u, B) { let l = ""; const E = []; let D = 0, m = null, H = null, Z = null, O = null, F = u; for (; F < B; ) { const T = this.readTag(F); if (!T) break; switch (T.fieldNumber) { case 1: if (T.wireType === 0) { const K = this.readVarint(T.dataStart); E.push(K.value), F = K.end; } else if (T.wireType === 2) { const K = this.readVarint(T.dataStart); let J = K.end; const q = J + K.value; for (; J < q; ) { const eA = this.readVarint(J); E.push(eA.value), J = eA.end; } F = q; } else F = this.skipField(T); break; case 2: { const K = this.readVarint(T.dataStart); D = K.value, F = K.end; } break; case 4: if (T.wireType === 2) { const K = this.readVarint(T.dataStart), J = K.end, q = this.buffer.slice(J, J + K.value); H = new Float32Array(q.buffer, 0, K.value / 4), F = J + K.value; } else F = this.skipField(T); break; case 5: if (T.wireType === 2) { const K = this.readVarint(T.dataStart), J = K.end, q = J + K.value, eA = []; let iA = J; for (; iA < q; ) { const $ = this.readVarint(iA); eA.push($.value | 0), iA = $.end; } Z = new Int32Array(eA), F = q; } else if (T.wireType === 0) { const K = this.readVarint(T.dataStart); Z = new Int32Array([K.value]), F = K.end; } else F = this.skipField(T); break; case 7: if (T.wireType === 2) { const K = this.readVarint(T.dataStart), J = K.end, q = this.buffer.slice(J, J + K.value); O = new BigInt64Array(q.buffer, 0, K.value / 8), F = J + K.value; } else F = this.skipField(T); break; case 8: { const K = this.readVarint(T.dataStart), J = this.buffer.subarray(K.end, K.end + K.value); l = new TextDecoder().decode(J), F = K.end + K.value; } break; case 9: { const K = this.readVarint(T.dataStart); m = this.buffer.subarray(K.end, K.end + K.value), F = K.end + K.value; } break; default: F = this.skipField(T); } } if (!l) return null; let X; return m ? X = m : H ? X = new Uint8Array(H.buffer, H.byteOffset, H.byteLength) : Z ? X = new Uint8Array(Z.buffer, Z.byteOffset, Z.byteLength) : O ? X = new Uint8Array(O.buffer, O.byteOffset, O.byteLength) : X = new Uint8Array(0), { name: l, dims: E, dataType: D, rawData: X }; } readTag(u) { if (u >= this.buffer.length) return null; const B = this.readVarint(u), l = B.value; return { fieldNumber: l >>> 3, wireType: l & 7, dataStart: B.end }; } readVarint(u) { let B = 0, l = 0, E = u; for (; E < this.buffer.length; ) { const D = this.buffer[E]; if (B |= (D & 127) << l, E++, (D & 128) === 0 || (l += 7, l > 35)) break; } return { value: B, end: E }; } skipField(u) { switch (u.wireType) { case 0: return this.readVarint(u.dataStart).end; case 1: return u.dataStart + 8; case 2: { const B = this.readVarint(u.dataStart); return B.end + B.value; } case 5: return u.dataStart + 4; default: throw new Error(`Unknown wire type: ${u.wireType}`); } } findField(u, B, l, E) { let D = B; for (; D < l; ) { const m = this.readTag(D); if (!m) break; if (m.fieldNumber === E && m.wireType === 2) { const H = this.readVarint(m.dataStart); return { start: H.end, end: H.end + H.value }; } D = this.skipField(m); } return null; } } function qa(k) { const u = new Float32Array(k.length); for (let B = 0; B < k.length; B++) { const l = k[B], E = l >> 15 & 1, D = l >> 10 & 31, m = l & 1023; D === 0 ? u[B] = (E ? -1 : 1) * Math.pow(2, -14) * (m / 1024) : D === 31 ? u[B] = m === 0 ? E ? -1 / 0 : 1 / 0 : NaN : u[B] = (E ? -1 : 1) * Math.pow(2, D - 15) * (1 + m / 1024); } return u; } async function h4(k) { const u = new Uint8Array(k), B = /* @__PURE__ */ new Map(); let l = 0; for (; l < u.length - 4 && !(u[l] !== 80 || u[l + 1] !== 75 || u[l + 2] !== 3 || u[l + 3] !== 4); ) { const E = new DataView(k, l), D = E.getUint16(8, !0); let m = E.getUint32(18, !0); const H = E.getUint16(26, !0), Z = E.getUint16(28, !0), O = new TextDecoder().decode(u.subarray(l + 30, l + 30 + H)); if (m === 4294967295) { let X = l + 30 + H; const T = X + Z; for (; X + 4 <= T; ) { const K = new DataView(k, X), J = K.getUint16(0, !0), q = K.getUint16(2, !0); if (J === 1 && q >= 16) { const eA = K.getUint32(12, !0), iA = K.getUint32(16, !0); m = eA + iA * 4294967296; break; } X += 4 + q; } } const F = l + 30 + H + Z; if (O.endsWith(".npy") && D === 0) { const X = u.subarray(F, F + m), T = w4(X.buffer, X.byteOffset), K = O.replace(".npy", ""); B.set(K, T); } l = F + m; } return B; } function w4(k, u) { const B = E4(k, u); return { shape: B.shape, data: B.data }; } function E4(k, u = 0) { const B = new Uint8Array(k, u); if (B[0] !== 147 || B[1] !== 78) throw new Error("Invalid .npy magic number"); const l = B[6]; let E, D; l === 1 ? (E = new DataView(k, u + 8).getUint16(0, !0), D = 10) : (E = new DataView(k, u + 8).getUint32(0, !0), D = 12); const m = new TextDecoder().decode(B.subarray(D, D + E)), H = D + E, Z = m.match(/shape['"]\s*:\s*\(([^)]*)\)/), O = Z ? Z[1].split(",").map((J) => J.trim()).filter((J) => J).map(Number) : [], F = m.match(/descr['"]\s*:\s*'([^']*)'/), X = F ? F[1] : "<f4", T = O.length === 0 ? 1 : O.reduce((J, q) => J * q, 1); let K; if (X === "<f4" || X === "=f4" || X === "float32") { const J = new Uint8Array(k, u + H, T * 4), q = new Uint8Array(T * 4); q.set(J), K = new Float32Array(q.buffer); } else if (X === "<f2" || X === "=f2" || X === "float16") { const J = new Uint8Array(k, u + H, T * 2), q = new Uint8Array(T * 2); q.set(J); const eA = new Uint16Array(q.buffer); K = qa(eA); } else if (X === "<i8" || X === "=i8" || X === "int64") { const J = new Uint8Array(k, u + H, T * 8), q = new Uint8Array(T * 8); q.set(J); const eA = new DataView(q.buffer); K = new Float32Array(T); for (let iA = 0; iA < T; iA++) K[iA] = eA.getInt32(iA * 8, !0); } else throw new Error(`Unsupported .npy dtype: ${X}`); return { shape: O, data: K, dtype: X }; } const Q4 = ( /* wgsl */ ` @group(0) @binding(0) var<storage, read> embeddings: array<f32>; @group(0) @binding(1) var<storage, read> input_ids: array<i32>; @group(0) @binding(2) var<storage, read_write> output: array<f32>; struct Params { seq_len: u32, embed_dim: u32, vocab_size: u32, } @group(0) @binding(3) var<uniform> params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let idx = gid.x; let seq_idx = idx / params.embed_dim; let dim_idx = idx % params.embed_dim; if (seq_idx >= params.seq_len) { return; } let token_id = input_ids[seq_idx]; let embed_offset = u32(token_id) * params.embed_dim + dim_idx; output[idx] = embeddings[embed_offset]; } ` ), G4 = ( /* wgsl */ ` @group(0) @binding(0) var<storage, read> input: array<f32>; @group(0) @binding(1) var<storage, read> gamma: array<f32>; @group(0) @binding(2) var<storage, read> beta: array<f32>; @group(0) @binding(3) var<storage, read_write> output: array<f32>; struct Params { batch_size: u32, hidden_size: u32, eps: f32, } @group(0) @binding(4) var<uniform> params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let batch_idx = gid.x; if (batch_idx >= params.batch_size) { return; } let offset = batch_idx * params.hidden_size; // Compute mean var sum = 0.0; for (var i = 0u; i < params.hidden_size; i++) { sum += input[offset + i]; } let mean = sum / f32(params.hidden_size); // Compute variance var var_sum = 0.0; for (var i = 0u; i < params.hidden_size; i++) { let diff = input[offset + i] - mean; var_sum += diff * diff; } let variance = var_sum / f32(params.hidden_size); let inv_std = 1.0 / sqrt(variance + params.eps); // Normalize for (var i = 0u; i < params.hidden_size; i++) { output[offset + i] = (input[offset + i] - mean) * inv_std * gamma[i] + beta[i]; } } ` ), p4 = ( /* wgsl */ ` // Tiled matmul with shared memory. TILE=16, each workgroup computes a 16×16 output tile. // Reduces global memory reads by factor of TILE compared to naive approach. const TILE: u32 = 16u; @group(0) @binding(0) var<storage, read> A: array<f32>; @group(0) @binding(1) var<storage, read> B: array<f32>; @group(0) @binding(2) var<storage, read> bias: array<f32>; @group(0) @binding(3) var<storage, read_write> output: array<f32>; struct Params { M: u32, // rows of A / output K: u32, // cols of A / rows of B N: u32, // cols of B / output use_bias: u32, } @group(0) @binding(4) var<uniform> params: Params; var<workgroup> tileA: array<f32, 256>; // 16×16 var<workgroup> tileB: array<f32, 256>; // 16×16 @compute @workgroup_size(16, 16) fn main( @builtin(global_invocation_id) gid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>, ) { let row = gid.x; let col = gid.y; let lr = lid.x; let lc = lid.y; var sum = 0.0; let numTiles = (params.K + TILE - 1u) / TILE; for (var t = 0u; t < numTiles; t++) { // Load tile of A: rows [row_base..+16], cols [t*16..+16] let aCol = t * TILE + lc; if (row < params.M && aCol < params.K) { tileA[lr * TILE + lc] = A[row * params.K + aCol]; } else { tileA[lr * TILE + lc] = 0.0; } // Load tile of B: rows [t*16..+16], cols [col_base..+16] let bRow = t * TILE + lr; if (bRow < params.K && col < params.N) { tileB[lr * TILE + lc] = B[bRow * params.N + col]; } else { tileB[lr * TILE + lc] = 0.0; } workgroupBarrier(); // Accumulate dot product from shared memory for (var k = 0u; k < TILE; k++) { sum += tileA[lr * TILE + k] * tileB[k * TILE + lc]; } workgroupBarrier(); } if (row < params.M && col < params.N) { if (params.use_bias != 0u) { sum += bias[col]; } output[row * params.N + col] = sum; } } ` ), m4 = ( /* wgsl */ ` @group(0) @binding(0) var<storage, read> input: array<f32>; // [C_in, L] @group(0) @binding(1) var<storage, read> weight: array<f32>; // [C_out, C_in, K] @group(0) @binding(2) var<storage, read> bias: array<f32>; // [C_out] @group(0) @binding(3) var<storage, read_write> output: array<f32>; // [C_out, L_out] struct Params { in_channels: u32, out_channels: u32, kernel_size: u32, input_length: u32, output_length: u32, padding: u32, stride: u32, dilation: u32, use_bias: u32, } @group(0) @binding(4) var<uniform> params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let idx = gid.x; let out_ch = idx / params.output_length; let out_pos = idx % params.output_length; if (out_ch >= params.out_channels) { return; } var sum = 0.0; for (var ic = 0u; ic < params.in_channels; ic++) { for (var k = 0u; k < params.kernel_size; k++) { let in_pos_raw = i32(out_pos * params.stride) + i32(k * params.dilation) - i32(params.padding); if (in_pos_raw >= 0 && u32(in_pos_raw) < params.input_length) { let w_idx = out_ch * params.in_channels * params.kernel_size + ic * params.kernel_size + k; let in_idx = ic * params.input_length + u32(in_pos_raw); sum += input[in_idx] * weight[w_idx]; } } } if (params.use_bias != 0u) { sum += bias[out_ch]; } output[out_ch * params.output_length + out_pos] = sum; } ` ), D4 = ( /* wgsl */ ` @group(0) @binding(0) var<storage, read> input: array<f32>; // [C, L] @group(0) @binding(1) var<storage, read_write> output: array<f32>; // [C, L] struct Params { channels: u32, length: u32, eps: f32, } @group(0) @binding(2) var<uniform> params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let ch = gid.x; if (ch >= params.channels) { return; } let offset = ch * params.length; // Compute mean var sum = 0.0; for (var i = 0u; i < params.length; i++) { sum += input[offset + i]; } let mean = sum / f32(params.length); // Compute variance var var_sum = 0.0; for (var i = 0u; i < params.length; i++) { let diff = input[offset + i] - mean; var_sum += diff * diff; } let variance = var_sum / f32(params.length); let inv_std = 1.0 / sqrt(variance + params.eps); // Normalize (no scale/bias for instance norm in this model - AdaIN handles that) for (var i = 0u; i < params.length; i++) { output[offset + i] = (input[offset + i] - mean) * inv_std; } } ` ), v4 = ( /* wgsl */ ` @group(0) @binding(0) var<storage, read> normed: array<f32>; // [C, L] - instance-normed input @group(0) @binding(1) var<storage, read> style_fc: array<f32>; // [2*C] - first C = scale, second C = bias @group(0) @binding(2) var<storage, read_write> output: array<f32>; // [C, L] struct Params { channels: u32, length: u32, } @group(0) @binding(3) var<uniform> params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let idx = gid.x; let ch = idx / params.length; let pos = idx % params.length; if (ch >= params.channels) { return; } // AdaIN: (1 + gamma) * normed + beta — the +1 offset is universal across all AdaIN blocks // style_fc layout: [scale_0..scale_{C-1}, bias_0..bias_{C-1}] let scale = style_fc[ch]; let bias = style_fc[params.channels + ch]; output[idx] = normed[idx] * (scale + 1.0) + bias; } ` ), x4 = ( /* wgsl */ ` @group(0) @binding(0) var<storage, read> normed: array<f32>; // [rows, C] @group(0) @binding(1) var<storage, read> style_fc: array<f32>; // [2*C] @group(0) @binding(2) var<storage, read_write> output: array<f32>; // [rows, C] struct Params { channels: u32, total: u32, } @group(0) @binding(3) var<uniform> params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let idx = gid.x; if (idx >= params.total) { return; } // Row-major: channel = idx % channels let ch = idx % params.channels; let scale = style_fc[ch]; let bias = style_fc[params.channels + ch]; output[idx] = normed[idx] * (scale + 1.0) + bias; } ` ), M4 = ( /* wgsl */ ` @group(0) @binding(0) var<storage, read> input: array<f32>; // [C, L] @group(0) @binding(1) var<storage, read> alpha: array<f32>; // [C] (flattened from [1, C, 1]) @group(0) @binding(2) var<storage, read_write> output: array<f32>; // [C, L] struct Params { channels: u32, length: u32, } @group(0) @binding(3) var<uniform> params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let idx = gid.x; let ch = idx / params.length; let pos = idx % params.length; if (ch >= params.channels) { return; } let x = input[idx]; let a = alpha[ch]; let sin_ax = sin(a * x); // Snake: x + (1/a) * sin²(a * x) output[idx] = x + sin_ax * sin_ax / a; } ` ), y4 = ( /* wgsl */ ` @group(0) @binding(0) var<storage, read> input: array<f32>; @group(0) @binding(1) var<storage, read_write> output: array<f32>; struct Params { size: u32, alpha: f32, } @group(0) @binding(2) var<uniform> params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let idx = gid.x; if (idx >= params.size) { return; } let x = input[idx]; output[idx] = select(params.alpha * x, x, x >= 0.0); } ` ), Y4 = ( /* wgsl */ ` @group(0) @binding(0) var<storage, read> input: array<f32>; @group(0) @binding(1) var<storage, read_write> output: array<f32>; struct Params { size: u32, } @group(0) @binding(2) var<uniform> params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let idx = gid.x; if (idx >= params.size) { return; } let x = input[idx]; // GELU approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) // Clamp tanh arg to prevent exp(2x) overflow in f32 (exp overflows at ~88.72) let c = 0.7978845608; // sqrt(2/pi) let inner = clamp(c * (x + 0.044715 * x * x * x), -44.0, 44.0); output[idx] = 0.5 * x * (1.0 + tanh(inner)); } ` ), N4 = ( /* wgsl */ ` @group(0) @binding(0) var<storage, read> input: array<f32>; @group(0) @binding(1) var<storage, read_write> output: array<f32>; struct Params { size: u32 } @group(0) @binding(2) var<uniform> params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let idx = gid.x; if (idx >= params.size) { return; } output[idx] = tanh(input[idx]); } ` ), H4 = ( /* wgsl */ ` @group(0) @binding(0) var<storage, read> input: array<f32>; @group(0) @binding(1) var<storage, read_write> output: array<f32>; struct Params { size: u32 } @group(0) @binding(2) var<uniform> params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let idx = gid.x; if (idx >= params.size) { return; } output[idx] = 1.0 / (1.0 + exp(-input[idx])); } ` ), P4 = ( /* wgsl */ ` @group(0) @binding(0) var<storage, read> input: array<f32>; // [C_in, L_in] @group(0) @binding(1) var<storage, read> weight: array<f32>; // [C_in, C_out, K] @group(0) @binding(2) var<storage, read> bias: array<f32>; // [C_out] @group(0) @binding(3) var<storage, read_write> output: array<f32>; // [C_out, L_out] struct Params { in_channels: u32, out_channels: u32, kernel_size: u32, input_length: u32, output_length: u32, stride: u32, padding: u32, use_bias: u32, } @group(0) @binding(4) var<uniform> params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let idx = gid.x; let out_ch = idx / params.output_length; let out_pos = idx % params.output_length; if (out_ch >= params.out_channels) { return; } var sum = 0.0; for (var ic = 0u; ic < params.in_channels; ic++) { for (var k = 0u; k < params.kernel_size; k++) { // ConvTranspose: output[out_pos] += input[in_pos] * weight[ic, out_ch, k] // where out_pos = in_pos * stride + k - padding // so in_pos = (out_pos + padding - k) / stride let numerator = i32(out_pos) + i32(params.padding) - i32(k); if (numerator >= 0 && u32(numerator) % params.stride == 0u) { let in_pos = u32(numerator) / params.stride; if (in_pos < params.input_length) { let w_idx = ic * params.out_channels * params.kernel_size + out_ch * params.kernel_size + k; let in_idx = ic * params.input_length + in_pos; sum += input[in_idx] * weight[w_idx]; } } } } if (params.use_bias != 0u) { sum += bias[out_ch]; } output[out_ch * params.output_length + out_pos] = sum; } ` ), z4 = ( /* wgsl */ ` @group(0) @binding(0) var<storage, read> input: array<f32>; // [channels, L_in] @group(0) @binding(1) var<storage, read> weight: array<f32>; // [channels, 1, K] = [channels * K] @group(0) @binding(2) var<storage, read_write> output: array<f32>; // [channels, L_out] struct Params { channels: u32, kernel_size: u32, input_length: u32, output_length: u32, stride: u32, padding: u32, } @group(0) @binding(3) var<uniform> params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let idx = gid.x; let ch = idx / params.output_length; let out_pos = idx % params.output_length; if (ch >= params.channels) { return; } var sum = 0.0; for (var k = 0u; k < params.kernel_size; k++) { let numerator = i32(out_pos) + i32(params.padding) - i32(k); if (numerator >= 0 && u32(numerator) % params.stride == 0u) { let in_pos = u32(numerator) / params.stride; if (in_pos < params.input_length) { let w_idx = ch * params.kernel_size + k; let in_idx = ch * params.input_length + in_pos; sum += input[in_idx] * weight[w_idx]; } } } output[ch * params.output_length + out_pos] = sum; } ` ), O4 = ( /* wgsl */ ` @group(0) @binding(0) var<storage, read> input: array<f32>; // [channels, L_in] @group(0) @binding(1) var<storage, read_write> output: array<f32>; // [channels, L_out] struct Params { channels: u32, input_length: u32, output_length: u32, } @group(0) @binding(2) var<uniform> params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let idx = gid.x; let ch = idx / params.output_length; let out_pos = idx % params.output_length; if (ch >= params.channels) { return; } // Nearest neighbor: map output position to input position let in_pos = out_pos * params.input_length / params.output_length; output[ch * params.output_length + out_pos] = input[ch * params.input_length + in_pos]; } ` ), F4 = ( /* wgsl */ ` @group(0) @binding(0) var<storage, read> input: array<f32>; @group(0) @binding(1) var<storage, read_write> output: array<f32>; struct Params { batch_size: u32, dim_size: u32, } @group(0) @binding(2) var<uniform> params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let batch_idx = gid.x; if (batch_idx >= params.batch_size) { return; } let offset = batch_idx * params.dim_size; // Find max for numerical stability var max_val = input[offset]; for (var i = 1u; i < params.dim_size; i++) { max_val = max(max_val, input[offset + i]); } // Compute exp and sum var exp_sum = 0.0; for (var i = 0u; i < params.dim_size; i++) { let e = exp(input[offset + i] - max_val); output[offset + i] = e; exp_sum += e; } // Normalize for (var i = 0u; i < params.dim_size; i++) { output[offset + i] /= exp_sum; } } ` ), W4 = ( /* wgsl */ ` @group(0) @binding(0) var<storage, read> Q: array<f32>; // [seq_len, num_heads, head_dim] @group(0) @binding(1) var<storage, read> K: array<f32>; // [seq_len, num_heads, head_dim] @group(0) @binding(2) var<storage, read> V: array<f32>; // [seq_len, num_heads, head_dim] @group(0) @binding(3) var<storage, read_write> output: array<f32>; // [seq_len, num_heads, head_dim] struct Params { seq_len: u32, num_heads: u32, head_dim: u32, scale: f32, // 1/sqrt(head_dim) } @group(0) @binding(4) var<uniform> params: Params; // Workgroup: one per (head, query_pos). Threads iterate over key positions. // We use a simple approach: each thread computes one output element (head_dim index). @compute @workgroup_size(64) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { // gid.x = dim_idx within head, gid.y = head_idx * seq_len + query_pos let dim_idx = gid.x; let head_query = gid.y; let head_idx = head_query / params.seq_len; let q_pos = head_query % params.seq_len; if (dim_idx >= params.head_dim || head_idx >= params.num_heads) { return; } let hd = params.head_dim; let nh = params.num_heads; let sl = params.seq_len; // Q vector for this (q_pos, head): Q[q_pos * nh * hd + head_idx * hd + ...] let q_base = q_pos * nh * hd + head_idx * hd; // Compute attention scores: dot(Q[q_pos, head], K[k_pos, head]) for all k_pos // Then softmax and weighted sum of V // Since we can't do cross-thread softmax easily, each thread computes full attention // for one output dimension. This is O(seq_len * head_dim) per thread but simple. // Step 1: Compute all attention scores (each thread does this redundantly) // For short sequences (< 512) this is fine var max_score = -1e10; for (var k = 0u; k < sl; k++) { let k_base = k * nh * hd + head_idx * hd; var score = 0.0; for (var d = 0u; d < hd; d++) { score += Q[q_base + d] * K[k_base + d]; } score *= params.scale; max_score = max(max_score, score); } // Step 2: Softmax var exp_sum = 0.0; var weighted_val = 0.0; for (var k = 0u; k < sl; k++) { let k_base = k * nh * hd + head_idx * hd; var score = 0.0; for (var d = 0u; d < hd; d++) { score += Q[q_base + d] * K[k_base + d]; } score *= params.scale; let w = exp(score - max_score); exp_sum += w; // Accumulate V[k_pos, head, dim_idx] weighted by attention let v_base = k * nh * hd + head_idx * hd; weighted_val += w * V[v_base + dim_idx]; } let out_idx = q_pos * nh * hd + head_idx * hd + dim_idx; output[out_idx] = weighted_val / exp_sum; } ` ), T4 = ( /* wgsl */ ` // Tiled matmul + GELU with shared memory. const TILE: u32 = 16u; @group(0) @binding(0) var<storage, read> A: array<f32>; @group(0) @binding(1) var<storage, read> B: array<f32>; @group(0) @binding(2) var<storage, read> bias: array<f32>; @group(0) @binding(3) var<storage, read_write> output: array<f32>; struct Params { M: u32, K: u32, N: u32, } @group(0) @binding(4) var<uniform> params: Params; var<workgroup> tileA: array<f32, 256>; var<workgroup> tileB: array<f32, 256>; @compute @workgroup_size(16, 16) fn main( @builtin(global_invocation_id) gid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>, ) { let row = gid.x; let col = gid.y; let lr = lid.x; let lc = lid.y; var sum = 0.0; let numTiles = (params.K + TILE - 1u) / TILE; for (var t = 0u; t < numTiles; t++) { let aCol = t * TILE + lc; if (row < params.M && aCol < params.K) { tileA[lr * TILE + lc] = A[row * params.K + aCol]; } else { tileA[lr * TILE + lc] = 0.0; } let bRow = t * TILE + lr; if (bRow < params.K && col < params.N) { tileB[lr * TILE + lc] = B[bRow * params.N + col]; } else { tileB[lr * TILE + lc] = 0.0; } workgroupBarrier(); for (var k = 0u; k < TILE; k++) { sum += tileA[lr * TILE + k] * tileB[k * TILE + lc]; } workgroupBarrier(); } if (row < params.M && col < params.N) { sum += bias[col]; // GELU activation (clamp tanh arg to prevent f32 exp overflow) let c = 0.7978845608; let x = sum; let inner = clamp(c * (x + 0.044715 * x * x * x), -44.0, 44.0); output[row * params.N + col] = 0.5 * x * (1.0 + tanh(inner)); } } ` ), K4 = ( /* wgsl */ ` @group(0) @binding(0) var<storage, read> a: array<f32>; @group(0) @binding(1) var<storage, read> b: array<f32>; @group(0) @binding(2) var<storage, read_write> output: array<f32>; struct Params { size: u32 } @group(0) @binding(3) var<uniform> params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let idx = gid.x; if (idx >= params.size) { return; } output[idx] = a[idx] + b[idx]; } ` ), Z4 = ( /* wgsl */ ` @group(0) @binding(0) var<storage, read> input: array<f32>; @group(0) @binding(1) var<storage, read_write> output: array<f32>; struct Params { size: u32, _pad1: u32, scale: f32, } @group(0) @binding(2) var<uniform> params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let idx = gid.x; if (idx >= params.size) { return; } output[idx] = input[idx] * params.scale; } ` ), L4 = ( /* wgsl */ ` @group(0) @binding(0) var<storage, read> a: array<f32>; // [C_a, L] @group(0) @binding(1) var<storage, read> b: array<f32>; // [C_b, L] @group(0) @binding(2) var<storage, read_write> output: array<f32>; // [C_a + C_b, L] struct Params { channels_a: u32, channels_b: u32, length: u32, } @group(0) @binding(3) var<uniform> params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let idx = gid.x; let total = (params.channels_a + params.channels_b) * params.length; if (idx >= total) { return; } let ch = idx / params.length; let pos = idx % params.length; if (ch < params.channels_a) { output[idx] = a[ch * params.length + pos]; } else { output[idx] = b[(ch - params.channels_a) * params.length + pos]; } } ` ), X4 = ( /* wgsl */ ` @group(0) @binding(0) var<storage, read> a: array<f32>; // [rows, cols_a] @group(0) @binding(1) var<storage, read> b: array<f32>; // [cols_b] — broadcast to every row @group(0) @binding(2) var<storage, read_write> output: array<f32>; // [rows, cols_a + cols_b] struct Params { rows: u32, cols_a: u32, cols_b: u32, } @group(0) @binding(3) var<uniform> params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let idx = gid.x; let total_cols = params.cols_a + params.cols_b; let total = params.rows * total_cols; if (idx >= total) { return; } let row = idx / total_cols; let col = idx % total_cols; if (col < params.cols_a) { output[idx] = a[row * params.cols_a + col]; } else { output[idx] = b[col - params.cols_a]; } } ` ), S4 = ( /* wgsl */ ` @group(0) @binding(0) var<storage, read> input: array<f32>; // [channels, L_in] @group(0) @binding(1) var<storage, read_write> output: array<f32>; // [channels, L_out] struct Params { channels: u32, input_length: u32, pad_left: u32, pad_right: u32, } @group(0) @binding(2) var<uniform> params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let idx = gid.x; let out_length = params.input_length + params.pad_left + params.pad_right; let ch = idx / out_length; let out_pos = idx % out_length; if (ch >= params.channels) { return; } var in_pos: u32; if (out_pos < params.pad_left) { // Reflected left: position 0 -> pad_left, position 1 -> pad_left-1, etc. in_pos = params.pad_left - out_pos; } else if (out_pos >= params.pad_left + params.input_length) { // Reflected right let overshoot = out_pos - params.pad_left - params.input_length; in_pos = params.input_length - 2u - overshoot; } else { in_pos = out_pos - params.pad_left; } output[ch * out_length + out_pos] = input[ch * params.input_length + in_pos]; } ` ), R4 = ( /* wgsl */ ` @group(0) @binding(0) var<storage, read> current: array<f32>; // conv2 output @group(0) @binding(1) var<storage, read> residual: array<f32>; // residual from previous iteration @group(0) @binding(2) var<storage, read> alpha: array<f32>; // [1, channels, 1] per-channel alpha @group(0) @binding(3) var<storage, read_write> output: array<f32>; struct Params { channels: u32, length: u32, } @group(0) @binding(4) var<uniform> params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let idx = gid.x; let ch = idx / params.length; if (ch >= params.channels) { return; } // output = current + alpha[ch] * residual output[idx] = current[idx] + alpha[ch] * residual[idx]; } ` ), U4 = ( /* wgsl */ ` @group(0) @binding(0) var<storage, read> input: array<f32>; @group(0) @binding(1) var<storage, read_write> output: array<f32>; struct Params { rows: u32, cols: u32 } @group(0) @binding(2) var<uniform> params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let idx = gid.x; let total = params.rows * params.cols; if (idx >= total) { return; } let row = idx / params.cols; let col = idx % params.cols; output[col * params.rows + row] = input[idx]; } ` ), V4 = ( /* wgsl */ ` @group(0) @binding(0) var<storage, read> input: array<f32>; // [seq_len, input_size] @group(0) @binding(1) var<storage, read> W: array<f32>; // [num_dir, input_size, 4*hidden] @group(0) @binding(2) var<storage, read> R: array<f32>; // [num_dir, hidden, 4*hidden] @group(0) @binding(3) var<storage, read> bias: array<f32>; // [num_dir, 8*hidden] @group(0) @binding(4) var<storage, read_write> output: array<f32>; // [seq_len, num_dir, hidden] struct Params { seq_len: u32, input_size: u32, hidden_size: u32, num_directions: u32, } @group(0) @binding(5) var<uniform> params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let h_idx = gid.x; // which hidden unit let dir = gid.y; // 0=forward, 1=backward // NOTE: no early return — all threads in workgroup must reach storageBarrier() let is_valid = h_idx < params.hidden_size && dir < params.num_directions; let H = params.hidden_size; let H4 = H * 4u; let IS = params.input_size; let SL = params.seq_len; // Use safe indices for inactive threads (they won't write) let safe_h = select(0u, h_idx, is_valid); let safe_dir = select(0u, dir, is_valid); // Gate offsets within 4*hidden: i=0, o=1, f=2, c=3 (ONNX order) let gate_i = safe_h; let gate_o = H + safe_h; let gate_f = 2u * H + safe_h; let gate_c = 3u * H + safe_h; // Bias offsets: [Wb_i, Wb_o, Wb_f, Wb_c, Rb_i, Rb_o, Rb_f, Rb_c] let bias_base = safe_dir * 8u * H; var b_wi = 0.0; var b_wo = 0.0; var b_wf = 0.0; var b_wc = 0.0; var b_ri = 0.0; var b_ro = 0.0; var b_rf = 0.0; var b_rc = 0.0; if (is_valid) { b_wi = bias[bias_base + safe_h]; b_wo = bias[bias_base + H + safe_h]; b_wf = bias[bias_base + 2u * H + safe_h]; b_wc = bias[bias_base + 3u * H + safe_h]; b_ri = bias[bias_base + 4u * H + safe_h]; b_ro = bias[bias_base + 5u * H + safe_h]; b_rf = bias[bias_base + 6u * H + safe_h]; b_rc = bias[bias_base + 7u * H + safe_h]; } var h_val = 0.0; // hidden state for this unit var c_val = 0.0; // cell state for this unit // Weight base offsets for this direction // W: [num_dir, IS, 4H] — flat stride: dir * IS * H4 // R: [num_dir, H, 4H] — flat stride: dir * H * H4 let w_base = safe_dir * IS * H4; let r_base = safe_dir * H * H4; for (var step = 0u; step < SL; step++) { if (is_valid) { // Forward: t=step, Backward: t=SL-1-step let t = select(SL - 1u - step, step, safe_dir == 0u); // Compute gates from input: sum over input_size var gi = b_wi + b_ri; var go = b_wo + b_ro; var gf = b_wf + b_rf; var gc = b_wc + b_rc; // Input contribution: W[dir, j, gate*H+h_idx] — layout [IS, 4H] // x[j] * W[w_base + j * H4 + gate_offset] for (var j = 0u; j < IS; j++) { let x_val = input[t * IS + j]; let w_off = w_base + j * H4; gi += x_val * W[w_off + gate_i]; go += x_val * W[w_off + gate_o]; gf += x_val * W[w_off + gate_f]; gc += x_val * W[w_off + gate_c]; } // Recurrence contribution: R[dir, j, gate*H+h_idx] — layout [H, 4H] // h_prev[j] * R[r_base + j * H4 + gate_offset] if (step > 0u) { let prev_t = select(SL - step, step - 1u, safe_dir == 0u); let prev_base = prev_t * params.num_directions * H + safe_dir * H; for (var j = 0u; j < H; j++) { let h_prev = output[prev_base + j]; let r_off = r_base + j * H4; gi += h_prev * R[r_off + gate_i]; go += h_prev * R[r_off + gate_o]; gf += h_prev * R[r_off + gate_f]; gc += h_prev * R[r_off + gate_c]; } } // Apply activations // Clamp sigmoid inputs to avoid exp overflow (exp(88.72) > f32 max) let i_gate = 1.0 / (1.0 + exp(-clamp(gi, -44.0, 44.0))); // sigmoid let o_gate = 1.0 / (1.0 + exp(-clamp(go, -44.0, 44.0))); let f_gate = 1.0 / (1.0 + exp(-clamp(gf, -44.0, 44.0))); // Clamp tanh inputs: tanh uses exp(2x), so |x| > 44 → exp(88) → Inf → NaN let c_gate = tanh(clamp(gc, -44.0, 44.0)); c_val = f_gate * c_val + i_gate * c_gate; h_val = o_gate * tanh(clamp(c_val, -44.0, 44.0)); // Write output: [t, dir, h_idx] → flat: t * num_dir * H + dir * H + h_idx output[t * params.num_directions * H + safe_dir * H + safe_h] = h_val; } // Barrier: all threads (active and inactive) must reach this point // storageBarrier() ensures visibility of storage buffer writes across threads in the workgroup storageBarrier(); } } ` ), J4 = ( /* wgsl */ ` @group(0) @binding(0) var<storage, read> input: array<f32>; // [seq_len, dim] @group(0) @binding(1) var<storage, read> cumsum: array<u32>; // [seq_len] prefix sum of durations @group(0) @binding(2) var<storage, read_write> output: array<f32>; // [total_frames, dim] struct Params { seq_len: u32, dim: u32, total_frames: u32, } @group(0) @binding(3) var<uniform> params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let idx = gid.x; let total = params.total_frames * params.dim; if (idx >= total) { return; } let frame = idx / params.dim; let d = idx % params.dim; // Binary search: find token i where cumsum[i-1] <= frame < cumsum[i] var lo: u32 = 0u; var hi: u32 = params.seq_len; while (lo < hi) { let mid = (lo + hi) / 2u; if (cumsum[mid] <= frame) { lo = mid + 1u; } else { hi = mid; } } let token = lo; output[idx] = input[token * params.dim + d]; } ` ), j4 = ( /* wgsl */ ` @group(0) @binding(0) var<storage, read> input: array<f32>; // [seq_len, dim] row-major @group(0) @binding(1) var<storage, read> cumsum: array<u32>; // [seq_len] prefix sum of durations @group(0) @binding(2) var<storage, read_write> output: array<f32>; // [dim, total_frames] channel-first struct Params { seq_len: u32, dim: u32, total_frames: u32, } @group(0) @binding(3) var<uniform> params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let idx = gid.x; let total = params.total_frames * params.dim; if (idx >= total) { return; } // Output layout: [dim, total_frames] — idx = channel * total_frames + frame let channel = idx / params.total_frames; let frame = idx % params.total_frames; // Binary search: find token i where cumsum[i-1] <= frame < cumsum[i] var lo: u32 = 0u; var hi: u32 = params.seq_len; while (lo < hi) { let mid = (lo + hi) / 2u; if (cumsum[mid] <= frame) { lo = mid + 1u; } else { hi = mid; } } let token = lo; output[idx] = input[token * params.dim + channel]; } ` ), q4 = ( /* wgsl */ ` // iSTFT synthesis: conv_post [22, genLength] → waveform [waveformLength] // Gather-based ConvTranspose: each thread computes one output sample // Fuses: magnitude/phase split, exp, sin(sin(ph)), cos(sin(ph)), ConvTranspose scatter @group(0) @binding(0) var<storage, read> conv_post: array<f32>; // [22, gen_length] @group(0) @binding(1) var<storage, read> weight_real: array<f32>; // [11, 20] @group(0) @binding(2) var<storage, read> weight_imag: array<f32>; // [11, 20] @group(0) @binding(3) var<storage, read_write> output: array<f32>; // [waveform_length] struct Params { gen_length: u32, waveform_length: u32, bins: u32, // 11 kernel_size: u32, // 20 stride: u32, // 5 } @group(0) @binding(4) var<uniform> params: Params; @compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let out_pos = gid.x; if (out_pos >= params.waveform_length) { return; } var sum: f32 = 0.0; // For each kernel tap, check if this output position has a contribution for (var k: u32 = 0u; k < params.kernel_size; k = k + 1u) { if (out_pos < k) { continue; } let rem = out_pos - k; if (rem % params.stride != 0u) { continue; } let t = rem / params.stride; if (t >= params.gen_length) { continue; } // For each frequency bin, compute magnitude/phase and accumulate for (var b: u32 = 0u; b < params.bins; b = b + 1u) { let mag_val = conv_post[b * params.gen_length + t]; let ph_val = conv_post[(b + params.bins) * params.gen_length + t]; let mag = exp(mag_val); let sin_ph = sin(ph_val); let real_comp = mag * cos(sin_ph); let imag_comp = mag * sin(sin_ph); sum += real_comp * weight_real[b * params.kernel_size + k] - imag_comp * weight_imag[b * params.kernel_size + k]; } } output[out_pos] = sum; } ` ); class _4 { device; weights = /* @__PURE__ */ new Map(); /** Aliases from canonical (mini) onnx:: weight names to actual weight names in the loaded model. */ weightAliases = /* @__PURE__ */ new Map(); pipelines = /* @__PURE__ */ new Map(); voices = /* @__PURE__ */ new Map(); // [400, 256] per voice // ── Dynamic model dimensions (detected from weight shapes after loading) ── /** LSTM hidden size for text encoder / predictor / duration / shared LSTMs (mini=256, nano=64). */ lstmHidden = 256; /** Bidirectional LSTM output size = 2 * lstmHidden (mini=512, nano=128). */ lstmBidir = 512; /** Text encoder embedding / CNN channel dim (mini=512, nano varies). Detected from text_encoder embedding weight. */ textEncChannels = 512; /** Style embedding total dimension from voices.npz (typically 256 for all model sizes). */ styleDim = 256; /** Style predictor half-dim = styleDim / 2 (typically 128). */ styleHalf = 128; /** LSTM input size for predictor/duration/shared = lstmBidir + styleHalf. */ lstmInputSize = 640; /** BERT embedding dim (mini=128, nano may differ). */ bertEmbedDim = 128; /** BERT hidden size (mini=768, nano may differ). */ bertHiddenSize = 768; /** BERT number of attention heads (mini=12). */ bertNumHeads = 12; /** BERT head dim = bertHiddenSize / bertNumHeads. */ bertHeadDim = 64; /** BERT FFN intermediate dim (mini=2048). */ bertFfnDim = 2048; /** BERT number of layer iterations (mini=12). */ bertNumLayers = 12; /** Number of predictor LSTM+FC pairs (mini=3, nano=2). */ numPredLstmPairs = 3; /** BERT encoder projection output dim (= lstmBidir, mini=512). */ bertProjDim = 512; /** Number of text encoder CNN blocks (mini=3, nano=2). */ numTextEncCnnBlocks = 3; /** Decoder encode output channels (mini=1024, nano=256). Detected from weight shapes. */ decEncodeOutCh = 1024; /** Decoder decode.0-2 output channels (mini=1024, nano=256). */ decDecodeOutCh = 1024; /** Decoder decode.3 output channels (mini=512, nano=256). */ decDecode3OutCh = 512; /** HiFi-GAN ups.0 output channels (mini=256, nano=128). Detected from weight shapes. */ hifiUps0OutCh = 256; /** HiFi-GAN ups.1 output channels (mini=128, nano=64). */ hifiUps1OutCh = 128; /** N/F0 predictor block0 output channels (mini=512, nano=128). */ predBlock0OutCh = 512; /** N/F0 predictor block1+ output channels (mini=256, nano=64). */ predBlock1OutCh = 256; config; /** Uniform buffers created during dispatch, cleaned up after submit. */ pendingUniformBuffers = []; /** CPU-side weight cache for re-uploading after freeGpuWeights(). * Populated during loadModel() so we can free/re-upload GPU buffers * between generations to prevent iOS Safari jetsam kills. */ weightCache = /* @__PURE__ */ new Map(); /** Pending command buffers for batch submission. */ pendingCommandBuffers = []; /** Shared command encoder for batching dispatches (reduces iOS Safari crashes). * Dispatches are recorded into this encoder and only submitted at readBuffer * boundaries or when flushSharedEncoder() is called explicitly. */ sharedEncoder = null; /** Buffers to destroy after the shared encoder is submitted. */ deferredDestroys = []; /** Buffer pool: reuse GPU buffers by byte size instead of destroy+reallocate. * Key insight: reusing buffers avoids Metal accumulating dead references to * destroyed buffers, which is the root cause of iOS jetsam kills. */ bufferPool = /* @__PURE__ */ new Map(); /** Buffers to return to pool (not destroy) after the next shared encoder flush. */ deferredPoolReturns = []; /** Cached CPU copies of sin generator weights (avoid readBuffer every inference). */ sinGenWeights = null; /** Debug mode: when true, intermediate activations are captured for comparison. */ debugCapture = !1; /** Captured activations (name → {data, shape}). Only populated when debugCapture=true. */ debugActivations = /* @__PURE__ */ new Map(); debugBertBuffers = null; /** Performance profiling: when true, logs timing per pipeline stage. */ profile = !1; timings = /* @__PURE__ */ new Map(); _stageStart = 0; constructor(u = ja) { this.config = u; } /** Start timing a pipeline stage. Call endStage() to record. */ startStage() { this.profile && (this._stageStart = performance.now()); } /** End timing and record the stage duration (includes GPU sync). * ALWAYS flushes batched dispatches + deferred destroys to keep peak GPU * memory low (prevents iOS Safari jetsam kills). */ async endStage(u) { if (this.flushBatchEncoder(), this.profile) { await this.device.queue.onSubmittedWorkDone(); const B = performance.now() - this._stageStart; this.timings.set(u, B); } } /** Last timing report from generate(), available after each call. */ lastTimings = []; /** Print timing summary to console and store for external access. */ printTimings() { if (!this.profile) return; this.lastTimings = []; let u = 0; const B = []; for (const [l, E] of this.timings) u += E, this.lastTimings.push({ name: l, ms: E }), B.push(` ${l.padEnd(35)} ${E.toFixed(1).padStart(8)} ms`); console.log(` [KittenTTS] ── Timing Report ──`); for (const l of B) console.log(l); console.log(` ${"─".repeat(45)}`), console.log(` ${"TOTAL".padEnd(35)} ${u.toFixed(1).padStart(8)} ms`), this.timings.clear(); } /** Capture a GPU buffer's contents as a named debug activation. No-op when debugCapture is off. */ async captureDebug(u, B, l) { if (!this.debugCapture) return; this.endBatch(); const E = l.reduce((O, F) => O * F, 1), D = await this.readBuffer(B, E); this.debugActivations.set(u, { data: D, shape: l }); let m = 1 / 0, H = -1 / 0, Z = 0; for (let O = 0; O < D.length; O++) { if (isNaN(D[O])) { Z++; continue; } D[O] < m && (m = D[O]), D[O] > H && (H = D[O]); } console.log(`[DEBUG] Captured ${u}: shape=[${l}], range=[${m}, ${H}], NaN=${Z}/${D.length}`); } /** Initialize WebGPU device and compile shaders. */ async init() { const u = await navigator.gpu?.requestAdapter(); if (!u) throw new Error("WebGPU not available"); const B = u.limits, l = Math.min(256 * 1024 * 1024, B.maxStorageBufferBindingSize), E = Math.min(256 * 1024 * 1024, B.maxBufferSize); console.log(`[KittenTTS] Adapter limits: maxStorageBuffer=${l}, maxBuffer=${E}`), this.device = await u.requestDevice({ requiredLimits: { maxStorageBufferBindingSize: l, maxBufferSize: E } }), this.device.lost.then((D) => { console.error(`[KittenTTS] Device lost: ${D.reason} — ${D.message}`), window.dispatchEvent(new CustomEvent("webgpu-device-lost", { detail: D })); }), this.device.addEventListener("uncapturederror", (D) => { const m = D; console.error(`[KittenTTS] GPU error: ${m.error.message}`), window.dispatchEvent(new CustomEvent("webgpu-error", { detail: m.error.message })); }), this.compileShaders(), console.log("[KittenTTS] WebGPU device initialized"); } /** Load model weights from ONNX file and voices from NPZ. */ async loadModel(u, B) { console.log("[KittenTTS] Loading model..."); const l = await fetch(u).then((O) => O.arrayBuffer()), D = new d4(l).parseInitializers(); console.log(`[KittenTTS] Parsed ${D.size} weight tensors`); const m = (O) => { let F; O.endsWith("_quantized") ? F = O.slice(0, -10) : F = O; const X = D.get(`${F}_scale`), T = D.get(`${F}_zero_point`); let K = new Float32Array([1]), J = new Int32Array([0]); if (X && X.rawData.length >= 4) { const q = X.rawData.length / 4, eA = new Uint8Array(q * 4); eA.set(X.rawData.subarray(0, q * 4)), K = new Float32Array(eA.buffer); } if (T && T.rawData.length >= 1) { const q = T.dims.length === 0 ? 1 : T.dims.reduce(($, bA) => $