web-asr-core
Version: (not captured in this extract)
WebASR Core - Browser-based speech processing with VAD, WakeWord and Whisper - Unified all-in-one version
330 lines (328 loc) • 12.3 kB
JavaScript
() => {
// esbuild-style helpers used to initialize class fields as own,
// enumerable properties on an instance.
var __defProp = Object.defineProperty;
var __defNormalProp = (obj, key, value) => {
  // If the key is already visible on the object (own or inherited, e.g.
  // "toString"), (re)define it as an enumerable data property so the
  // assignment cannot be swallowed by a prototype setter; otherwise a
  // plain assignment suffices.
  if (key in obj) {
    return __defProp(obj, key, {
      enumerable: true,
      configurable: true,
      writable: true,
      value
    });
  }
  return obj[key] = value;
};
var __publicField = (obj, key, value) => {
  // Non-symbol keys are coerced to strings before the property is set.
  const normalizedKey = typeof key === "symbol" ? key : `${key}`;
  return __defNormalProp(obj, normalizedKey, value);
};
// src/workers/onnx-inference.worker.ts
// Bootstraps the ONNX inference worker: detect the worker type, then load
// onnxruntime-web (local node_modules copy first, jsDelivr CDN fallback).
console.log("[Worker] Starting initialization...");
// Detect module vs. classic worker. Classic workers expose a callable
// importScripts(); invoking it with an empty data: URI succeeds there but
// throws inside a module worker.
var isModuleWorker = (() => {
try {
if (typeof importScripts === "function") {
// importScripts works => classic worker.
importScripts("data:text/javascript,");
return false;
}
// importScripts not defined at all => module worker.
return true;
} catch (e) {
// NOTE(review): matches a Chromium-specific error string ("Module
// scripts don't support importScripts"); other engines may word this
// differently — verify on non-Chromium browsers.
return String(e).includes("Module scripts don't support importScripts");
}
})();
console.log("[Worker] Worker type check:", {
isModuleWorker,
hasImportScripts: typeof importScripts === "function",
hasWebGPU: !!self.navigator?.gpu,
workerType: isModuleWorker ? "MODULE" : "CLASSIC"
});
// Load the ONNX Runtime script synchronously: prefer the local
// node_modules copy, fall back to the CDN when it is not served.
// NOTE(review): importScripts throws in a module worker, so this path
// assumes a classic worker — confirm against how the worker is spawned.
try {
importScripts("../../node_modules/onnxruntime-web/dist/ort.min.js");
console.log("[Worker] ONNX Runtime loaded from node_modules");
} catch (e) {
console.log("[Worker] Loading ONNX Runtime from CDN (node_modules not available)");
importScripts("https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js");
}
var WorkerEventBridge = class {
  /**
   * Post a named event to the main thread.
   * @param {string} eventType - Event name understood by the main thread.
   * @param {*} detail - Event payload.
   */
  static emitEvent(eventType, detail) {
    const message = {
      type: "event",
      event: eventType,
      detail,
      timestamp: Date.now()
    };
    self.postMessage(message);
  }
  /**
   * Report a processing error to the main thread as a
   * "processing-error" event (message + stack only; Error objects do
   * not survive structured clone with their fields intact).
   * @param {Error} error - The error to serialize.
   * @param {string} context - Short label of where the error occurred.
   */
  static emitError(error, context) {
    const { message, stack } = error;
    this.emitEvent("processing-error", {
      error: { message, stack },
      context: `worker:${context}`
    });
  }
};
var ONNXInferenceWorker = class {
  /**
   * Runs VAD and wake-word inference with onnxruntime-web inside a worker.
   * InferenceSessions are cached per model + execution-provider combination,
   * and VAD recurrent state is kept per session key across chunks.
   */
  constructor() {
    // Cached InferenceSession instances, keyed by `${modelName}_${providers}`.
    __publicField(this, "sessions", /* @__PURE__ */ new Map());
    // Per-session VAD LSTM state (Float32Array, shape [2, 1, 128]).
    __publicField(this, "vadStates", /* @__PURE__ */ new Map());
    // Per-session "speech currently detected" flag.
    __publicField(this, "vadActiveStates", /* @__PURE__ */ new Map());
    __publicField(this, "isWebGPUAvailable", false);
    // initialize() is async; attach a handler so a failure (e.g. `ort`
    // missing) does not become an unhandled promise rejection.
    this.initialize().catch((error) => {
      console.error("[ONNX Worker] Initialization failed:", error);
    });
  }
  /**
   * @param {string} sessionKey - VAD session identifier.
   * @returns {boolean} Whether speech is currently active for the session.
   */
  isVadActive(sessionKey) {
    return this.vadActiveStates.get(sessionKey) || false;
  }
  /** Record the current speech-active flag for a VAD session. */
  setVadActive(sessionKey, active) {
    this.vadActiveStates.set(sessionKey, active);
  }
  /**
   * Probe WebGPU availability, configure ONNX Runtime WASM settings,
   * choose where to load the .wasm binaries from (local vs. CDN), then
   * notify the main thread with an "initialized" message.
   */
  async initialize() {
    try {
      const hasGPU = !!self.navigator?.gpu;
      if (hasGPU) {
        // NOTE(review): on Windows the default adapter options are used,
        // presumably to avoid a bad "high-performance" adapter — confirm.
        const isWin2 = self.navigator.userAgent?.includes("Windows");
        const opts = isWin2 ? {} : { powerPreference: "high-performance" };
        const adapter = await self.navigator.gpu.requestAdapter(opts);
        if (adapter) {
          this.isWebGPUAvailable = true;
          console.log("[ONNX Worker] WebGPU is available:", adapter?.name || "adapter");
        }
      }
    } catch (error) {
      console.log("[ONNX Worker] WebGPU not available:", error);
    }
    ort.env.wasm.simd = true;
    ort.env.wasm.numThreads = navigator.hardwareConcurrency || 4;
    const isWin = self.navigator.userAgent?.includes("Windows");
    if (!isWin) {
      ort.env.webgpu.powerPreference = "high-performance";
    }
    // Prefer local WASM binaries; fall back to the CDN when unavailable.
    try {
      const testFetch = await fetch("../../node_modules/onnxruntime-web/dist/ort-wasm.wasm", { method: "HEAD" });
      if (testFetch.ok) {
        ort.env.wasm.wasmPaths = "../../node_modules/onnxruntime-web/dist/";
        console.log("[ONNX Worker] Using local WASM files from node_modules");
      } else {
        throw new Error("Local WASM not available");
      }
    } catch (e) {
      ort.env.wasm.wasmPaths = "https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/";
      console.log("[ONNX Worker] Using CDN for WASM files");
    }
    self.postMessage({
      type: "initialized",
      data: { webgpuAvailable: this.isWebGPUAvailable }
    });
  }
  /**
   * Load (or return a cached) InferenceSession for a model.
   * @param {string} modelName - Logical model name (cache key component).
   * @param {{modelPath: string, executionProviders: string[], webgpuOptions?: object, wasmOptions?: object}} config
   * @returns {Promise<object>} The ONNX InferenceSession.
   * @throws When the model cannot be fetched or created.
   */
  async loadModel(modelName, config) {
    const cacheKey = `${modelName}_${JSON.stringify(config.executionProviders)}`;
    if (this.sessions.has(cacheKey)) {
      return this.sessions.get(cacheKey);
    }
    console.log(`[ONNX Worker] Loading model: ${modelName}`);
    const executionProviders = [];
    console.log(`[ONNX Worker] WebGPU available: ${this.isWebGPUAvailable}, Requested providers:`, config.executionProviders);
    // Honor the requested providers, skipping WebGPU when unavailable.
    for (const provider of config.executionProviders) {
      if (provider === "webgpu" && this.isWebGPUAvailable) {
        console.log("[ONNX Worker] Adding WebGPU provider");
        executionProviders.push({
          name: "webgpu",
          ...config.webgpuOptions
        });
      } else if (provider === "wasm") {
        console.log("[ONNX Worker] Adding WASM provider");
        executionProviders.push({
          name: "wasm",
          ...config.wasmOptions
        });
      }
    }
    // Always keep at least one provider so session creation can proceed.
    if (executionProviders.length === 0) {
      executionProviders.push({ name: "wasm" });
    }
    try {
      let session;
      let modelUrl = config.modelPath;
      // Re-root relative "models/" paths against the worker's location.
      if (modelUrl.startsWith("models/") || modelUrl.startsWith("./models/")) {
        modelUrl = modelUrl.replace(/^\.\//, "");
        modelUrl = `/../../${modelUrl}`;
      }
      try {
        session = await ort.InferenceSession.create(
          modelUrl,
          { executionProviders }
        );
      } catch (pathError) {
        // Models with external data files cannot be loaded by URL;
        // retry with an explicitly fetched ArrayBuffer.
        if (pathError.message?.includes("external data file")) {
          console.log(`[ONNX Worker] Path loading failed, trying ArrayBuffer for ${modelName}`);
          const response = await fetch(modelUrl);
          if (!response.ok) {
            throw new Error(`Failed to fetch model: ${response.status}`);
          }
          const modelBuffer = await response.arrayBuffer();
          session = await ort.InferenceSession.create(
            modelBuffer,
            { executionProviders }
          );
        } else {
          throw pathError;
        }
      }
      this.sessions.set(cacheKey, session);
      console.log(`[ONNX Worker] Model loaded successfully: ${modelName}`);
      return session;
    } catch (error) {
      console.error(`[ONNX Worker] Failed to load model:`, error);
      throw error;
    }
  }
  /**
   * Run one VAD chunk through the model, carrying the recurrent state
   * between calls and emitting "speech-start"/"speech-end" events on
   * transitions.
   * @param {object} session - InferenceSession for the VAD model.
   * @param {Float32Array} inputData - One audio chunk at 16 kHz.
   * @param {string} [sessionKey="default"] - Key isolating state per stream.
   * @returns {Promise<{isSpeech: boolean, probability: number}>}
   */
  async runVADInference(session, inputData, sessionKey = "default") {
    let stateData = this.vadStates.get(sessionKey);
    // Only consult the active flag when state already exists for this key.
    const wasActive = stateData ? this.isVadActive(sessionKey) : false;
    if (!stateData) {
      stateData = new Float32Array(2 * 1 * 128);
      this.vadStates.set(sessionKey, stateData);
    }
    const state = new ort.Tensor("float32", stateData, [2, 1, 128]);
    const sr = new ort.Tensor("int64", BigInt64Array.from([16000n]), [1]);
    const feeds = {
      "input": new ort.Tensor("float32", inputData, [1, inputData.length]),
      "state": state,
      "sr": sr
    };
    const results = await session.run(feeds);
    // Persist the recurrent state for the next chunk; the output name
    // varies across model versions.
    const newState = results["state_out"] || results["stateN"] || results["state"];
    if (newState && newState.data) {
      const newStateData = new Float32Array(newState.data);
      this.vadStates.set(sessionKey, newStateData);
    }
    const output = results["output"] || results["21"] || results[Object.keys(results)[0]];
    if (!output || !output.data) {
      console.error("[ONNX Worker] No valid output from VAD model");
      return { isSpeech: false, probability: 0 };
    }
    const probability = output.data[0];
    const threshold = 0.15;
    const isSpeech = probability > threshold;
    // Emit events only on state transitions, not every chunk.
    if (!wasActive && isSpeech) {
      WorkerEventBridge.emitEvent("speech-start", {
        timestamp: Date.now(),
        probability
      });
    } else if (wasActive && !isSpeech) {
      WorkerEventBridge.emitEvent("speech-end", {
        timestamp: Date.now(),
        probability
      });
    }
    this.setVadActive(sessionKey, isSpeech);
    // (Removed dead code: per-chunk min/max/avg statistics were computed
    // and then discarded in an empty `if`; the Math.max(...inputData)
    // spread could also overflow the call stack on large chunks.)
    return {
      isSpeech,
      probability
    };
  }
  /**
   * Run wake-word inference and emit a "wakeword-detected" event when the
   * best class score exceeds 0.5.
   * @param {object} session - InferenceSession for the wake-word model.
   * @param {Float32Array} inputData - Embedding tensor or raw audio.
   * @param {string} [modelName="unknown"] - Reported in the detection event.
   * @returns {Promise<{detected: boolean, confidence: number, wordIndex: number|null}>}
   */
  async runWakeWordInference(session, inputData, modelName = "unknown") {
    let feeds;
    const embeddingDim = 96;
    // Heuristic: a length divisible by 96 is treated as a [1, T, 96]
    // embedding tensor; anything else as raw audio shaped [1, 1, N].
    const isEmbeddingTensor = inputData.length % embeddingDim === 0 && inputData.length >= embeddingDim;
    if (isEmbeddingTensor) {
      const timeSteps = inputData.length / embeddingDim;
      console.log(`[ONNX Worker] Detected embedding tensor: T=${timeSteps}, D=${embeddingDim}`);
      feeds = {
        "input": new ort.Tensor("float32", inputData, [1, timeSteps, embeddingDim])
      };
    } else {
      console.log(`[ONNX Worker] Detected raw audio tensor: length=${inputData.length}`);
      feeds = {
        "input": new ort.Tensor("float32", inputData, [1, 1, inputData.length])
      };
    }
    const results = await session.run(feeds);
    const output = results["output"] || results[Object.keys(results)[0]];
    const scores = Array.from(output.data);
    // Arg-max over class scores; starts at 0, so all-negative scores
    // leave detectedWord as null (treated as "nothing detected").
    let maxScore = 0;
    let detectedWord = null;
    scores.forEach((score, index) => {
      if (score > maxScore) {
        maxScore = score;
        detectedWord = index;
      }
    });
    const isDetected = maxScore > 0.5;
    if (isDetected) {
      WorkerEventBridge.emitEvent("wakeword-detected", {
        word: modelName,
        wordIndex: detectedWord,
        confidence: maxScore,
        timestamp: Date.now()
      });
    }
    return {
      detected: isDetected,
      confidence: maxScore,
      wordIndex: detectedWord
    };
  }
  /**
   * Dispatch one inference request (VAD or wake word). Never throws:
   * failures are reported via the event bridge and encoded in the
   * returned response object.
   * @param {{id: string, type: "vad"|string, modelName: string, config: object, inputData: Float32Array}} request
   * @returns {Promise<object>} Response with result or error, plus timing.
   */
  async processInference(request) {
    const startTime = performance.now();
    try {
      const session = await this.loadModel(request.modelName, request.config);
      let result;
      if (request.type === "vad") {
        result = await this.runVADInference(session, request.inputData, request.id);
      } else {
        result = await this.runWakeWordInference(session, request.inputData, request.modelName);
      }
      const executionTime = performance.now() - startTime;
      return {
        id: request.id,
        type: request.type,
        result,
        executionTime,
        // ONNX Runtime Web does not expose which provider actually ran.
        provider: "unknown"
      };
    } catch (error) {
      console.error(`[ONNX Worker] Inference failed for ${request.type}:`, error);
      WorkerEventBridge.emitError(error, `processInference-${request.type}`);
      return {
        id: request.id,
        type: request.type,
        result: null,
        error: error instanceof Error ? error.message : String(error),
        executionTime: performance.now() - startTime
      };
    }
  }
  /** Warm the session cache for a model ahead of first use. */
  async preloadModel(modelName, config) {
    await this.loadModel(modelName, config);
    console.log(`[ONNX Worker] Preloaded model: ${modelName}`);
  }
  /** Drop all cached sessions and per-session VAD state. */
  clearCache() {
    this.sessions.clear();
    this.vadStates.clear();
    this.vadActiveStates.clear();
    console.log("[ONNX Worker] Model cache and states cleared");
  }
};
var worker = new ONNXInferenceWorker();
// Message protocol with the main thread:
//   "inference"   -> run VAD/wake-word inference, reply "inference-result"
//   "preload"     -> warm a model into the cache, reply "preload-complete"
//   "clear-cache" -> drop sessions and VAD state, reply "cache-cleared"
self.addEventListener("message", async (event) => {
  const { type, data } = event.data;
  try {
    switch (type) {
      case "inference": {
        // processInference never throws; errors ride in the response.
        const response = await worker.processInference(data);
        self.postMessage({ type: "inference-result", data: response });
        break;
      }
      case "preload": {
        await worker.preloadModel(data.modelName, data.config);
        self.postMessage({ type: "preload-complete", data: { modelName: data.modelName } });
        break;
      }
      case "clear-cache": {
        worker.clearCache();
        self.postMessage({ type: "cache-cleared", data: {} });
        break;
      }
      default:
        console.warn(`[ONNX Worker] Unknown message type: ${type}`);
    }
  } catch (error) {
    // Previously a failed preload became an unhandled promise rejection
    // inside this async listener; report it to the main thread instead.
    console.error(`[ONNX Worker] Failed to handle "${type}" message:`, error);
    self.postMessage({
      type: "error",
      data: {
        requestType: type,
        message: error instanceof Error ? error.message : String(error)
      }
    });
  }
});
})();
;
(