
parakeet.js


NVIDIA Parakeet speech recognition for the browser (WebGPU/WASM) powered by ONNX Runtime Web.
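A minimal usage sketch (assumptions: the package entry point re-exports `ParakeetModel`, and the asset URLs below are placeholders for wherever the Parakeet ONNX export is hosted; the file names mirror those documented on `fromUrls`):

import { ParakeetModel } from 'parakeet.js';

// Placeholder URLs – point these at your own hosted copies of the
// Parakeet export (encoder, combined decoder_joint, vocab, preprocessor).
const model = await ParakeetModel.fromUrls({
  encoderUrl: '/models/encoder-model.onnx',
  decoderUrl: '/models/decoder_joint-model.onnx',
  tokenizerUrl: '/models/vocab.txt',
  preprocessorUrl: '/models/nemo128.onnx',
  backend: 'webgpu-hybrid', // or 'webgpu-strict' | 'wasm'
});

// `pcm` is a 16-kHz mono Float32Array of samples.
const { utterance_text } = await model.transcribe(pcm, 16000);
console.log(utterance_text);

The hybrid backend runs the encoder on WebGPU while pinning the combined decoder to WASM to avoid per-step GPU stalls, as explained in `fromUrls` below.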

import { initOrt } from './backend.js';
import { ParakeetTokenizer } from './tokenizer.js';
import { OnnxPreprocessor } from './preprocessor.js';

/**
 * Lightweight Parakeet model wrapper designed for browser usage.
 * Currently supports the *combined* decoder_joint-model ONNX (encoder+decoder+joiner in
 * 'transformers.js' style) exported by parakeet TDT.
 *
 * NOTE: This started as an early scaffold; `transcribe` below implements greedy
 * frame-by-frame TDT decoding.
 */
export class ParakeetModel {
  constructor({ tokenizer, encoderSession, joinerSession, preprocessor, ort, subsampling = 8, windowStride = 0.01, normalizer = (s) => s }) {
    this.tokenizer = tokenizer;
    this.encoderSession = encoderSession;
    this.joinerSession = joinerSession;
    this.preprocessor = preprocessor;
    this.ort = ort;

    // Default IDs – may later be read from model metadata.
    this.blankId = 1024;

    // Combined model specific constants
    this.predHidden = 640;
    this.predLayers = 2;
    this.maxTokensPerStep = 10;

    // Allocate zero LSTM states for the combined decoder; will be reused.
    const numLayers = this.predLayers;
    const hidden = this.predHidden;
    const size = numLayers * 1 * hidden;
    const z = new Float32Array(size); // zeros
    this._combState1 = new ort.Tensor('float32', z, [numLayers, 1, hidden]);
    this._combState2 = new ort.Tensor('float32', z.slice(), [numLayers, 1, hidden]);

    this._normalizer = normalizer;
    this.subsampling = subsampling;
    this.windowStride = windowStride;
  }
  /**
   * Create ParakeetModel by downloading all required assets.
   * @param {Object} cfg
   * @param {string} cfg.encoderUrl URL to encoder-model.onnx
   * @param {string} cfg.decoderUrl URL to decoder_joint-model.onnx
   * @param {string} cfg.tokenizerUrl URL to vocab.txt or tokens.txt
   * @param {string} cfg.preprocessorUrl URL to nemo80/128.onnx
   * @param {('webgpu-hybrid'|'webgpu-strict'|'wasm')} [cfg.backend='webgpu-hybrid']
   */
  static async fromUrls(cfg) {
    const {
      encoderUrl,
      decoderUrl,
      tokenizerUrl,
      preprocessorUrl,
      encoderDataUrl,
      decoderDataUrl,
      filenames,
      backend = 'webgpu-hybrid',
      wasmPaths,
      subsampling = 8,
      windowStride = 0.01,
      verbose = false,
      enableProfiling = false,
      enableGraphCapture,
      cpuThreads = undefined,
    } = cfg;

    if (!encoderUrl || !decoderUrl || !tokenizerUrl || !preprocessorUrl) {
      throw new Error('fromUrls requires encoderUrl, decoderUrl, tokenizerUrl and preprocessorUrl');
    }

    // 1. Init ONNX Runtime
    let ortBackend = backend;
    if (backend.startsWith('webgpu')) {
      ortBackend = 'webgpu';
    }
    const ort = await initOrt({ backend: ortBackend, wasmPaths, numThreads: cpuThreads });

    // 2. Configure session options for better performance
    // Graph-capture is beneficial only when every node runs on the same EP and
    // ORT can fully record the graph (currently true only for a “strict”
    // WebGPU session). We therefore enable it *only* when the caller passes
    // `enableGraphCapture:true` **and** the selected backend is the strict
    // WebGPU preset. In all other scenarios (hybrid WebGPU or pure WASM)
    // it is forced off to avoid the “External buffer must be provided …”
    // runtime error on recent ORT builds.
    const graphCaptureEnabled = !!enableGraphCapture && backend === 'webgpu-strict';
    const isFullWasm = backend === 'wasm';

    const baseSessionOptions = {
      executionProviders: [],
      graphOptimizationLevel: 'all',
      executionMode: 'parallel',
      enableCpuMemArena: true,
      enableMemPattern: true,
      enableProfiling,
      enableGraphCapture: graphCaptureEnabled,
      logSeverityLevel: verbose ? 0 : 2, // 0=verbose, 2=warning
    };

    // Set execution provider based on backend
    if (backend === 'webgpu-hybrid') {
      // WebGPU with fallback to WASM for encoder; decoder may be forced to WASM-only.
      baseSessionOptions.executionProviders = [
        { name: 'webgpu', deviceType: 'gpu', powerPreference: 'high-performance' },
        'wasm',
      ];
    } else if (backend === 'webgpu-strict') {
      baseSessionOptions.executionProviders = [
        { name: 'webgpu', deviceType: 'gpu', powerPreference: 'high-performance' },
      ];
    } else if (backend === 'wasm') {
      baseSessionOptions.executionProviders = ['wasm'];
    }

    console.log(`[Parakeet.js] Creating ONNX sessions with execution mode '${backend}'. Providers:`, baseSessionOptions.executionProviders);
    if (verbose) {
      console.log('[Parakeet.js] Verbose logging enabled for ONNX Runtime.');
    }

    // Create separate options for sessions that might have external data
    const encoderSessionOptions = { ...baseSessionOptions };
    if (encoderDataUrl && filenames?.encoder) {
      encoderSessionOptions.externalData = [{
        data: encoderDataUrl,
        path: filenames.encoder + '.data',
      }];
    }

    const decoderSessionOptions = { ...baseSessionOptions };
    if (decoderDataUrl && filenames?.decoder) {
      decoderSessionOptions.externalData = [{
        data: decoderDataUrl,
        path: filenames.decoder + '.data',
      }];
    }

    // In hybrid mode, the decoder is always run on WASM to avoid per-step
    // stalls. In pure WASM mode, both EPs are WASM anyway.
    if (backend.startsWith('webgpu')) {
      // Force decoder to run on WASM
      decoderSessionOptions.executionProviders = ['wasm'];
    }

    // 3. Load tokenizer & preprocessor in parallel with model sessions
    // helper to create session with graceful fallback if graph capture is unsupported
    async function createSession(url, opts) {
      try {
        return await ort.InferenceSession.create(url, opts);
      } catch (e) {
        const msg = (e.message || '') + '';
        if (opts.enableGraphCapture && msg.includes('graph capture')) {
          console.warn('[Parakeet] Graph-capture unsupported for this model/backend; retrying without it');
          const retryOpts = { ...opts, enableGraphCapture: false };
          return await ort.InferenceSession.create(url, retryOpts);
        }
        throw e;
      }
    }

    const tokenizerPromise = ParakeetTokenizer.fromUrl(tokenizerUrl);
    const preprocPromise = Promise.resolve(new OnnxPreprocessor(preprocessorUrl, {
      backend,
      wasmPaths,
      enableProfiling,
      enableGraphCapture: isFullWasm ? false : graphCaptureEnabled,
      numThreads: cpuThreads,
    }));

    let encoderSession, joinerSession;
    if (backend === 'webgpu-hybrid') {
      // avoid parallel create to prevent double initWasm race
      encoderSession = await createSession(encoderUrl, encoderSessionOptions);
      joinerSession = await createSession(decoderUrl, decoderSessionOptions);
    } else {
      [encoderSession, joinerSession] = await Promise.all([
        createSession(encoderUrl, encoderSessionOptions),
        createSession(decoderUrl, decoderSessionOptions),
      ]);
    }

    const [tokenizer, preprocessor] = await Promise.all([tokenizerPromise, preprocPromise]);

    return new ParakeetModel({ tokenizer, encoderSession, joinerSession, preprocessor, ort, subsampling, windowStride });
  }
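  /**
   * Run one step of the combined decoder+joiner ("decoder_joint") session.
   * Feeds a single encoder frame [1, D, 1], the previous token id and the two
   * LSTM state tensors; returns the token logits, the argmax of the TDT
   * duration head (`step`, measured in encoder frames) and the updated states.
   */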
  async _runCombinedStep(encTensor, token, currentState = null) {
    const singleToken = typeof token === 'number' ? token : this.blankId;
    const targetTensor = new this.ort.Tensor('int32', new Int32Array([singleToken]), [1, 1]);
    const lenTensor = new this.ort.Tensor('int32', new Int32Array([1]), [1]);

    const state1 = currentState?.state1 || this._combState1;
    const state2 = currentState?.state2 || this._combState2;

    const feeds = {
      encoder_outputs: encTensor,
      targets: targetTensor,
      target_length: lenTensor,
      input_states_1: state1,
      input_states_2: state2,
    };

    const out = await this.joinerSession.run(feeds);
    const logits = out['outputs'];
    const vocab = this.tokenizer.id2token.length;
    const totalDim = logits.dims[3];
    const data = logits.data;
    const tokenLogits = data.slice(0, vocab);
    const durLogits = data.slice(vocab, totalDim);

    let step = 0;
    if (durLogits.length) {
      let maxVal = -Infinity;
      for (let i = 0; i < durLogits.length; ++i) {
        if (durLogits[i] > maxVal) { maxVal = durLogits[i]; step = i; }
      }
    }

    const newState = {
      state1: out['output_states_1'] || state1,
      state2: out['output_states_2'] || state2,
    };

    return { tokenLogits, step, newState };
  }
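  /**
   * Run the ONNX preprocessor on raw PCM and return the flattened mel
   * features, laid out as [melBins, T], together with the frame count T.
   */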
  async computeFeatures(audio, sampleRate = 16000) {
    const { features, length } = await this.preprocessor.process(audio);
    const T = length; // number of frames returned by preprocessor
    const melBins = features.length / T;
    return { features, T, melBins };
  }
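  // Example (hypothetical caller; `pcm` is a 16-kHz mono Float32Array):
  //   const result = await model.transcribe(pcm, 16000, {
  //     returnTimestamps: true,
  //     returnConfidences: true,
  //   });
  //   console.log(result.utterance_text);
  //   console.log(result.words); // [{ text, start_time, end_time, confidence }, ...]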
  /**
   * Transcribe 16-kHz mono PCM. Returns full rich output (timestamps/confidences opt-in).
   */
  async transcribe(audio, sampleRate = 16000, opts = {}) {
    const {
      returnTimestamps = false,
      returnConfidences = false,
      temperature = 1.2,
      debug = false,
      skipCMVN = false,
      frameStride = 1,
    } = opts;

    const perfEnabled = true; // always collect and log timings
    let t0, tPreproc = 0, tEncode = 0, tDecode = 0, tToken = 0;
    if (perfEnabled) t0 = performance.now();

    // 1. Feature extraction (ONNX pre-processor)
    let features, T, melBins;
    if (perfEnabled) {
      const s = performance.now();
      ({ features, T, melBins } = await this.computeFeatures(audio, sampleRate));
      tPreproc = performance.now() - s;
    } else {
      ({ features, T, melBins } = await this.computeFeatures(audio, sampleRate));
    }

    // 2. Encode entire utterance
    const input = new this.ort.Tensor('float32', features, [1, melBins, T]);
    const lenTensor = new this.ort.Tensor('int64', BigInt64Array.from([BigInt(T)]), [1]);
    let enc;
    if (perfEnabled) {
      const s = performance.now();
      const encOut = await this.encoderSession.run({ audio_signal: input, length: lenTensor });
      tEncode = performance.now() - s;
      enc = encOut['outputs'] ?? Object.values(encOut)[0];
    } else {
      const encOut = await this.encoderSession.run({ audio_signal: input, length: lenTensor });
      enc = encOut['outputs'] ?? Object.values(encOut)[0];
    }

    // Transpose encoder output [B, D, T] ➔ [T, D] for B=1
    const [, D, Tenc] = enc.dims;
    const transposed = new Float32Array(Tenc * D);
    for (let d = 0; d < D; d++) {
      for (let t = 0; t < Tenc; t++) {
        transposed[t * D + d] = enc.data[d * Tenc + t];
      }
    }

    // --- Decode frame-by-frame ----------------------------------------
    const ids = [];
    const tokenTimes = [];
    const tokenConfs = [];
    const frameConfs = [];
    let overallLogProb = 0;
    let decoderState = null;
    let emittedTokens = 0;
    const decStartTime = perfEnabled ? performance.now() : 0;

    for (let t = 0; t < Tenc; ) {
      const frameBuf = transposed.subarray(t * D, (t + 1) * D);
      const encTensor = new this.ort.Tensor('float32', frameBuf, [1, D, 1]);
      const prevTok = ids.length ? ids[ids.length - 1] : this.blankId;
      const { tokenLogits, step, newState } = await this._runCombinedStep(encTensor, prevTok, decoderState);
      decoderState = newState;

      // Temperature scaling & argmax
      let maxVal = -Infinity, maxId = 0;
      for (let i = 0; i < tokenLogits.length; i++) {
        const v = tokenLogits[i] / temperature;
        if (v > maxVal) { maxVal = v; maxId = i; }
      }
      let sumExp = 0;
      for (let i = 0; i < tokenLogits.length; i++) {
        sumExp += Math.exp((tokenLogits[i] / temperature) - maxVal);
      }
      const confVal = 1 / sumExp;
      frameConfs.push(confVal);
      overallLogProb += Math.log(confVal);

      if (maxId !== this.blankId) {
        ids.push(maxId);
        if (returnTimestamps) {
          const TIME_STRIDE = this.subsampling * this.windowStride;
          const durFrames = step > 0 ? step : 1;
          const start = t * TIME_STRIDE;
          const end = (t + durFrames) * TIME_STRIDE;
          tokenTimes.push([start, end]);
        }
        if (returnConfidences) tokenConfs.push(confVal);
        emittedTokens += 1;
      }

      const shouldAdvance = maxId === this.blankId || emittedTokens >= this.maxTokensPerStep;
      t += step > 0 ? step : (shouldAdvance ? frameStride : 0);
      if (!shouldAdvance && step === 0) t += 1; // safeguard
      if (maxId === this.blankId) emittedTokens = 0;
    }

    if (perfEnabled) {
      tDecode = performance.now() - decStartTime;
    }

    let tokenStart;
    if (perfEnabled) tokenStart = performance.now();
    const text = this._normalizer(this.tokenizer.decode(ids));
    if (perfEnabled) tToken = performance.now() - tokenStart;

    // Early exit if no extras requested
    if (!returnTimestamps && !returnConfidences) {
      if (perfEnabled) {
        const total = performance.now() - t0;
        const audioDur = audio.length / sampleRate;
        const rtf = audioDur / (total / 1000);
        console.log(`[Perf] RTF: ${rtf.toFixed(2)}x (audio ${audioDur.toFixed(2)} s, time ${(total / 1000).toFixed(2)} s)`);
        console.table({
          Preprocess: `${tPreproc.toFixed(1)} ms`,
          Encode: `${tEncode.toFixed(1)} ms`,
          Decode: `${tDecode.toFixed(1)} ms`,
          Tokenize: `${tToken.toFixed(1)} ms`,
          Total: `${total.toFixed(1)} ms`,
        });
      }
      const metrics = perfEnabled ? {
        preprocess_ms: +tPreproc.toFixed(1),
        encode_ms: +tEncode.toFixed(1),
        decode_ms: +tDecode.toFixed(1),
        tokenize_ms: +tToken.toFixed(1),
        total_ms: +((performance.now() - t0).toFixed(1)),
        rtf: +((audio.length / sampleRate) / ((performance.now() - t0) / 1000)).toFixed(2),
      } : null;
      return { utterance_text: text, words: [], metrics, is_final: true };
    }

    // --- Build words & detailed token arrays ---------------------------
    const words = [];
    const tokensDetailed = [];
    let currentWord = '', wordStart = 0, wordEnd = 0;
    let wordConfs = [];

    ids.forEach((tokId, i) => {
      const raw = this.tokenizer.id2token[tokId];
      if (raw === this.tokenizer.blankToken) return;
      const isWordStart = raw.startsWith('▁');
      const cleanTok = isWordStart ? raw.slice(1) : raw;
      const ts = tokenTimes[i] || [null, null];
      const conf = tokenConfs[i];

      // tokensDetailed entry
      const tokEntry = { token: [cleanTok] };
      if (returnTimestamps) {
        tokEntry.start_time = +ts[0].toFixed(3);
        tokEntry.end_time = +ts[1].toFixed(3);
      }
      if (returnConfidences) tokEntry.confidence = +conf.toFixed(4);
      tokensDetailed.push(tokEntry);

      // accumulate into words
      if (isWordStart) {
        if (currentWord) {
          const avg = wordConfs.length ? wordConfs.reduce((a, b) => a + b, 0) / wordConfs.length : 0;
          words.push({
            text: currentWord,
            start_time: +wordStart.toFixed(3),
            end_time: +wordEnd.toFixed(3),
            confidence: +avg.toFixed(4),
          });
        }
        currentWord = cleanTok;
        if (returnTimestamps) { wordStart = ts[0]; wordEnd = ts[1]; }
        wordConfs = returnConfidences ? [conf] : [];
      } else {
        currentWord += cleanTok;
        if (returnTimestamps) wordEnd = ts[1];
        if (returnConfidences) wordConfs.push(conf);
      }
    });

    if (currentWord) {
      const avg = wordConfs.length ? wordConfs.reduce((a, b) => a + b, 0) / wordConfs.length : 0;
      words.push({
        text: currentWord,
        start_time: +wordStart.toFixed(3),
        end_time: +wordEnd.toFixed(3),
        confidence: +avg.toFixed(4),
      });
    }

    const avgWordConf = words.length && returnConfidences
      ? words.reduce((a, b) => a + b.confidence, 0) / words.length
      : null;
    const avgTokenConf = tokensDetailed.length && returnConfidences
      ? tokensDetailed.reduce((a, b) => a + (b.confidence || 0), 0) / tokensDetailed.length
      : null;

    if (perfEnabled) {
      const total = performance.now() - t0;
      const audioDur = audio.length / sampleRate;
      const rtf = audioDur / (total / 1000);
      console.log(`[Perf] RTF: ${rtf.toFixed(2)}x (audio ${audioDur.toFixed(2)} s, time ${(total / 1000).toFixed(2)} s)`);
      console.table({
        Preprocess: `${tPreproc.toFixed(1)} ms`,
        Encode: `${tEncode.toFixed(1)} ms`,
        Decode: `${tDecode.toFixed(1)} ms`,
        Tokenize: `${tToken.toFixed(1)} ms`,
        Total: `${total.toFixed(1)} ms`,
      });
    }

    return {
      utterance_text: text,
      words,
      tokens: tokensDetailed,
      confidence_scores: returnConfidences ? {
        token: tokenConfs.map((c) => +c.toFixed(4)),
        token_avg: +avgTokenConf?.toFixed(4),
        word: words.map((w) => w.confidence),
        word_avg: +avgWordConf?.toFixed(4),
        frame: frameConfs.map((f) => +f.toFixed(4)),
        frame_avg: frameConfs.length ? +(frameConfs.reduce((a, b) => a + b, 0) / frameConfs.length).toFixed(4) : null,
        overall_log_prob: +overallLogProb.toFixed(6),
      } : { overall_log_prob: null, frame: null, frame_avg: null },
      metrics: perfEnabled ? {
        preprocess_ms: +tPreproc.toFixed(1),
        encode_ms: +tEncode.toFixed(1),
        decode_ms: +tDecode.toFixed(1),
        tokenize_ms: +tToken.toFixed(1),
        total_ms: +((performance.now() - t0).toFixed(1)),
        rtf: +((audio.length / sampleRate) / ((performance.now() - t0) / 1000)).toFixed(2),
      } : null,
      is_final: true,
    };
  }

  /**
   * Stop ORT profiling (if enabled) for all sessions and print a quick summary
   * of time spent on GPU (WebGPU) vs CPU (WASM) kernels. Returns the parsed
   * summary object for further inspection.
   */
  endProfiling() {
    try { this.encoderSession?.endProfiling(); } catch (e) { /* ignore */ }
    try { this.joinerSession?.endProfiling(); } catch (e) { /* ignore */ }

    const FS = this.ort?.env?.wasm?.FS;
    if (!FS) {
      console.warn('[Parakeet] Profiling FS not accessible');
      return null;
    }

    const files = FS.readdir('/tmp').filter((f) => f.startsWith('profile_') && f.endsWith('.json'));
    if (!files.length) {
      console.warn('[Parakeet] No profiling files found. Was profiling enabled?');
      return null;
    }

    const summary = {};
    for (const file of files) {
      try {
        const txt = FS.readFile('/tmp/' + file, { encoding: 'utf8' });
        const events = JSON.parse(txt);
        let gpu = 0, cpu = 0;
        for (const ev of events) {
          if (ev.cat === 'Node') {
            const prov = ev.args?.provider;
            if (prov === 'webgpu') gpu += ev.dur;
            else if (prov) cpu += ev.dur;
          }
        }
        summary[file] = { gpu_us: gpu, cpu_us: cpu, total_us: gpu + cpu };
      } catch (err) {
        console.warn('[Parakeet] Failed to parse profile file', file, err);
      }
    }

    console.table(summary);
    return summary;
  }
}
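Profiling can be switched on at session creation and flushed after a run; a minimal sketch under the same placeholder assumptions as the usage example above:

// `cfg` is the same asset-URL object as in the usage sketch above.
const model = await ParakeetModel.fromUrls({ ...cfg, enableProfiling: true });
await model.transcribe(pcm, 16000);
// Prints a GPU-vs-CPU kernel-time table per profile file and returns the
// parsed summary ({ gpu_us, cpu_us, total_us } keyed by file name).
const summary = model.endProfiling();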