UNPKG

@jaehyun-ko/speaker-verification

Version:

Real-time speaker verification in the browser using NeXt-TDNN models

179 lines (178 loc) 7.42 kB
"use strict"; var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) { if (k2 === undefined) k2 = k; var desc = Object.getOwnPropertyDescriptor(m, k); if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) { desc = { enumerable: true, get: function() { return m[k]; } }; } Object.defineProperty(o, k2, desc); }) : (function(o, m, k, k2) { if (k2 === undefined) k2 = k; o[k2] = m[k]; })); var __setModuleDefault = (this && this.__setModuleDefault) || (Object.create ? (function(o, v) { Object.defineProperty(o, "default", { enumerable: true, value: v }); }) : function(o, v) { o["default"] = v; }); var __importStar = (this && this.__importStar) || (function () { var ownKeys = function(o) { ownKeys = Object.getOwnPropertyNames || function (o) { var ar = []; for (var k in o) if (Object.prototype.hasOwnProperty.call(o, k)) ar[ar.length] = k; return ar; }; return ownKeys(o); }; return function (mod) { if (mod && mod.__esModule) return mod; var result = {}; if (mod != null) for (var k = ownKeys(mod), i = 0; i < k.length; i++) if (k[i] !== "default") __createBinding(result, mod, k[i]); __setModuleDefault(result, mod); return result; }; })(); Object.defineProperty(exports, "__esModule", { value: true }); exports.NeXtTDNNModel = void 0; const ort = __importStar(require("onnxruntime-web")); class NeXtTDNNModel { constructor(config) { this.session = null; this.config = config; } async initialize() { try { // Configure ONNX Runtime ort.env.wasm.numThreads = navigator.hardwareConcurrency || 4; ort.env.wasm.simd = true; // Set execution providers const executionProviders = this.config.executionProviders || ['wasm']; // Create inference session based on the type of modelPath const sessionOptions = { executionProviders, graphOptimizationLevel: 'all' }; // Determine the model source (modelPath or modelData) const modelPath = this.config.modelPath; const modelData = this.config.modelData; if (!modelPath && !modelData) { throw new Error('Either modelPath or modelData must be provided'); } // Create session with the appropriate method based on input type if (modelPath && typeof modelPath === 'string') { // Load from URL/path this.session = await ort.InferenceSession.create(modelPath, sessionOptions); } else if (modelData || (modelPath && typeof modelPath !== 'string')) { // Load from ArrayBuffer or Uint8Array const data = modelData || modelPath; this.session = await ort.InferenceSession.create(data, sessionOptions); } else { throw new Error('Invalid model source'); } } catch (error) { throw new Error(`Failed to load model: ${error}`); } } async infer(melSpectrogram, numFrames) { if (!this.session) { throw new Error('Model not initialized. Call initialize() first.'); } try { // Get input metadata const inputName = this.session.inputNames[0]; const outputName = this.session.outputNames[0]; // Reshape input for model [batch_size, mel_bins, time_frames] const nMels = 80; // Fixed for our model const inputTensor = new ort.Tensor('float32', melSpectrogram, [1, nMels, numFrames]); // Run inference const startTime = performance.now(); const results = await this.session.run({ [inputName]: inputTensor }); const inferenceTime = performance.now() - startTime; // Extract embeddings const output = results[outputName]; const rawOutput = output.data; // Additional debugging if (rawOutput.length === 0) { throw new Error('Model output is empty'); } // Check for NaN values in raw output const hasNaN = Array.from(rawOutput).some(v => isNaN(v)); if (hasNaN) { throw new Error('Model output contains NaN values'); } // The model outputs [batch_size, hidden_dim, time_frames] or [batch_size, hidden_dim] // We need to check the dimensions let embeddings; if (output.dims.length === 2) { // Output is already [batch_size, hidden_dim], no pooling needed const [batchSize, hiddenDim] = output.dims; embeddings = new Float32Array(rawOutput); // Copy the output directly } else if (output.dims.length === 3) { // Output is [batch_size, hidden_dim, time_frames], need mean pooling const [batchSize, hiddenDim, timeFrames] = output.dims; embeddings = new Float32Array(hiddenDim); for (let h = 0; h < hiddenDim; h++) { let sum = 0; for (let t = 0; t < timeFrames; t++) { // Correct indexing for [batch=1, hidden_dim, time_frames] layout const index = h * timeFrames + t; sum += rawOutput[index]; } embeddings[h] = sum / timeFrames; } } else { throw new Error(`Unexpected output dimensions: ${output.dims}`); } // L2 normalize the embeddings let norm = 0; for (let i = 0; i < embeddings.length; i++) { norm += embeddings[i] * embeddings[i]; } norm = Math.sqrt(norm); // Prevent division by zero if (norm === 0 || isNaN(norm)) { norm = 1; // Prevent division by zero } for (let i = 0; i < embeddings.length; i++) { embeddings[i] = embeddings[i] / norm; } // Check L2 norm after normalization let normCheck = 0; for (let i = 0; i < embeddings.length; i++) { normCheck += embeddings[i] * embeddings[i]; } return { embedding: embeddings, timestamp: Date.now() }; } catch (error) { throw new Error(`Inference failed: ${error}`); } } async cleanup() { if (this.session) { await this.session.release(); this.session = null; } } // Get model metadata getModelInfo() { if (!this.session) return null; const inputName = this.session.inputNames[0]; const outputName = this.session.outputNames[0]; // Note: ONNX Runtime Web doesn't provide direct access to shapes // These are based on our model's expected shapes return { inputShape: [1, 80, -1], // -1 for dynamic time dimension outputShape: [1, 192] }; } } exports.NeXtTDNNModel = NeXtTDNNModel;