UNPKG

web-asr-core

Version:

WebASR Core: browser-based speech processing with VAD, wake-word detection, and Whisper (unified all-in-one build)

330 lines (328 loc) 12.3 kB
"use strict"; (() => { var __defProp = Object.defineProperty; var __defNormalProp = (obj, key, value) => key in obj ? __defProp(obj, key, { enumerable: true, configurable: true, writable: true, value }) : obj[key] = value; var __publicField = (obj, key, value) => __defNormalProp(obj, typeof key !== "symbol" ? key + "" : key, value); // src/workers/onnx-inference.worker.ts console.log("[Worker] Starting initialization..."); var isModuleWorker = (() => { try { if (typeof importScripts === "function") { importScripts("data:text/javascript,"); return false; } return true; } catch (e) { return String(e).includes("Module scripts don't support importScripts"); } })(); console.log("[Worker] Worker type check:", { isModuleWorker, hasImportScripts: typeof importScripts === "function", hasWebGPU: !!self.navigator?.gpu, workerType: isModuleWorker ? "MODULE" : "CLASSIC" }); try { importScripts("../../node_modules/onnxruntime-web/dist/ort.min.js"); console.log("[Worker] ONNX Runtime loaded from node_modules"); } catch (e) { console.log("[Worker] Loading ONNX Runtime from CDN (node_modules not available)"); importScripts("https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"); } var WorkerEventBridge = class { /** * 發送事件到主執行緒 */ static emitEvent(eventType, detail) { self.postMessage({ type: "event", event: eventType, detail, timestamp: Date.now() }); } /** * 發送處理錯誤事件 */ static emitError(error, context) { this.emitEvent("processing-error", { error: { message: error.message, stack: error.stack }, context: `worker:${context}` }); } }; var ONNXInferenceWorker = class { constructor() { __publicField(this, "sessions", /* @__PURE__ */ new Map()); __publicField(this, "vadStates", /* @__PURE__ */ new Map()); // 保存 VAD LSTM 狀態 __publicField(this, "vadActiveStates", /* @__PURE__ */ new Map()); // 保存 VAD 活動狀態 __publicField(this, "isWebGPUAvailable", false); this.initialize(); } isVadActive(sessionKey) { return this.vadActiveStates.get(sessionKey) || false; } setVadActive(sessionKey, 
active) { this.vadActiveStates.set(sessionKey, active); } async initialize() { try { const hasGPU = !!self.navigator?.gpu; if (hasGPU) { const isWin2 = self.navigator.userAgent?.includes("Windows"); const opts = isWin2 ? {} : { powerPreference: "high-performance" }; const adapter = await self.navigator.gpu.requestAdapter(opts); if (adapter) { this.isWebGPUAvailable = true; console.log("[ONNX Worker] WebGPU is available:", adapter?.name || "adapter"); } } } catch (error) { console.log("[ONNX Worker] WebGPU not available:", error); } ort.env.wasm.simd = true; ort.env.wasm.numThreads = navigator.hardwareConcurrency || 4; const isWin = self.navigator.userAgent?.includes("Windows"); if (!isWin) { ort.env.webgpu.powerPreference = "high-performance"; } try { const testFetch = await fetch("../../node_modules/onnxruntime-web/dist/ort-wasm.wasm", { method: "HEAD" }); if (testFetch.ok) { ort.env.wasm.wasmPaths = "../../node_modules/onnxruntime-web/dist/"; console.log("[ONNX Worker] Using local WASM files from node_modules"); } else { throw new Error("Local WASM not available"); } } catch (e) { ort.env.wasm.wasmPaths = "https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/"; console.log("[ONNX Worker] Using CDN for WASM files"); } self.postMessage({ type: "initialized", data: { webgpuAvailable: this.isWebGPUAvailable } }); } async loadModel(modelName, config) { const cacheKey = `${modelName}_${JSON.stringify(config.executionProviders)}`; if (this.sessions.has(cacheKey)) { return this.sessions.get(cacheKey); } console.log(`[ONNX Worker] Loading model: ${modelName}`); const executionProviders = []; console.log(`[ONNX Worker] WebGPU available: ${this.isWebGPUAvailable}, Requested providers:`, config.executionProviders); for (const provider of config.executionProviders) { if (provider === "webgpu" && this.isWebGPUAvailable) { console.log("[ONNX Worker] Adding WebGPU provider"); executionProviders.push({ name: "webgpu", ...config.webgpuOptions }); } else if (provider === "wasm") { 
console.log("[ONNX Worker] Adding WASM provider"); executionProviders.push({ name: "wasm", ...config.wasmOptions }); } } if (executionProviders.length === 0) { executionProviders.push({ name: "wasm" }); } try { let session; let modelUrl = config.modelPath; if (modelUrl.startsWith("models/") || modelUrl.startsWith("./models/")) { modelUrl = modelUrl.replace(/^\.\//, ""); modelUrl = `/../../${modelUrl}`; } try { session = await ort.InferenceSession.create( modelUrl, { executionProviders } ); } catch (pathError) { if (pathError.message?.includes("external data file")) { console.log(`[ONNX Worker] Path loading failed, trying ArrayBuffer for ${modelName}`); const response = await fetch(modelUrl); if (!response.ok) { throw new Error(`Failed to fetch model: ${response.status}`); } const modelBuffer = await response.arrayBuffer(); session = await ort.InferenceSession.create( modelBuffer, // Type assertion needed as ONNX Runtime types may be incomplete { executionProviders } ); } else { throw pathError; } } this.sessions.set(cacheKey, session); console.log(`[ONNX Worker] Model loaded successfully: ${modelName}`); return session; } catch (error) { console.error(`[ONNX Worker] Failed to load model:`, error); throw error; } } async runVADInference(session, inputData, sessionKey = "default") { let stateData = this.vadStates.get(sessionKey); const wasActive = stateData ? 
this.isVadActive(sessionKey) : false; if (!stateData) { stateData = new Float32Array(2 * 1 * 128); this.vadStates.set(sessionKey, stateData); } const state = new ort.Tensor("float32", stateData, [2, 1, 128]); const sr = new ort.Tensor("int64", BigInt64Array.from([16000n]), [1]); const feeds = { "input": new ort.Tensor("float32", inputData, [1, inputData.length]), "state": state, "sr": sr }; const results = await session.run(feeds); const newState = results["state_out"] || results["stateN"] || results["state"]; if (newState && newState.data) { const newStateData = new Float32Array(newState.data); this.vadStates.set(sessionKey, newStateData); } const output = results["output"] || results["21"] || results[Object.keys(results)[0]]; if (!output || !output.data) { console.error("[ONNX Worker] No valid output from VAD model"); return { isSpeech: false, probability: 0 }; } const probability = output.data[0]; const threshold = 0.15; const isSpeech = probability > threshold; if (!wasActive && isSpeech) { WorkerEventBridge.emitEvent("speech-start", { timestamp: Date.now(), probability }); } else if (wasActive && !isSpeech) { WorkerEventBridge.emitEvent("speech-end", { timestamp: Date.now(), probability }); } this.setVadActive(sessionKey, isSpeech); const maxVal = Math.max(...inputData); const minVal = Math.min(...inputData); const avgVal = inputData.reduce((a, b) => a + Math.abs(b), 0) / inputData.length; if (avgVal > 1e-3) { } return { isSpeech, probability }; } async runWakeWordInference(session, inputData, modelName = "unknown") { let feeds; const embeddingDim = 96; const isEmbeddingTensor = inputData.length % embeddingDim === 0 && inputData.length >= embeddingDim; if (isEmbeddingTensor) { const timeSteps = inputData.length / embeddingDim; console.log(`[ONNX Worker] Detected embedding tensor: T=${timeSteps}, D=${embeddingDim}`); feeds = { "input": new ort.Tensor("float32", inputData, [1, timeSteps, embeddingDim]) }; } else { console.log(`[ONNX Worker] Detected raw audio 
tensor: length=${inputData.length}`); feeds = { "input": new ort.Tensor("float32", inputData, [1, 1, inputData.length]) }; } const results = await session.run(feeds); const output = results["output"] || results[Object.keys(results)[0]]; const scores = Array.from(output.data); let maxScore = 0; let detectedWord = null; scores.forEach((score, index) => { if (score > maxScore) { maxScore = score; detectedWord = index; } }); const isDetected = maxScore > 0.5; if (isDetected) { WorkerEventBridge.emitEvent("wakeword-detected", { word: modelName, wordIndex: detectedWord, confidence: maxScore, timestamp: Date.now() }); } return { detected: isDetected, confidence: maxScore, wordIndex: detectedWord }; } async processInference(request) { const startTime = performance.now(); try { const session = await this.loadModel(request.modelName, request.config); let result; if (request.type === "vad") { result = await this.runVADInference(session, request.inputData, request.id); } else { result = await this.runWakeWordInference(session, request.inputData, request.modelName); } const executionTime = performance.now() - startTime; return { id: request.id, type: request.type, result, executionTime, provider: "unknown" // ONNX Runtime Web 不公開 provider 資訊 }; } catch (error) { console.error(`[ONNX Worker] Inference failed for ${request.type}:`, error); WorkerEventBridge.emitError(error, `processInference-${request.type}`); return { id: request.id, type: request.type, result: null, error: error instanceof Error ? 
error.message : String(error), executionTime: performance.now() - startTime }; } } async preloadModel(modelName, config) { await this.loadModel(modelName, config); console.log(`[ONNX Worker] Preloaded model: ${modelName}`); } clearCache() { this.sessions.clear(); this.vadStates.clear(); this.vadActiveStates.clear(); console.log("[ONNX Worker] Model cache and states cleared"); } }; var worker = new ONNXInferenceWorker(); self.addEventListener("message", async (event) => { const { type, data } = event.data; switch (type) { case "inference": const response = await worker.processInference(data); self.postMessage({ type: "inference-result", data: response }); break; case "preload": await worker.preloadModel(data.modelName, data.config); self.postMessage({ type: "preload-complete", data: { modelName: data.modelName } }); break; case "clear-cache": worker.clearCache(); self.postMessage({ type: "cache-cleared", data: {} }); break; default: console.warn(`[ONNX Worker] Unknown message type: ${type}`); } }); })();