@jaehyun-ko/speaker-verification
Real-time speaker verification in the browser using NeXt-TDNN models
JavaScript
"use strict";
var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) {
if (k2 === undefined) k2 = k;
var desc = Object.getOwnPropertyDescriptor(m, k);
if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) {
desc = { enumerable: true, get: function() { return m[k]; } };
}
Object.defineProperty(o, k2, desc);
}) : (function(o, m, k, k2) {
if (k2 === undefined) k2 = k;
o[k2] = m[k];
}));
var __setModuleDefault = (this && this.__setModuleDefault) || (Object.create ? (function(o, v) {
Object.defineProperty(o, "default", { enumerable: true, value: v });
}) : function(o, v) {
o["default"] = v;
});
var __importStar = (this && this.__importStar) || (function () {
var ownKeys = function(o) {
ownKeys = Object.getOwnPropertyNames || function (o) {
var ar = [];
for (var k in o) if (Object.prototype.hasOwnProperty.call(o, k)) ar[ar.length] = k;
return ar;
};
return ownKeys(o);
};
return function (mod) {
if (mod && mod.__esModule) return mod;
var result = {};
if (mod != null) for (var k = ownKeys(mod), i = 0; i < k.length; i++) if (k[i] !== "default") __createBinding(result, mod, k[i]);
__setModuleDefault(result, mod);
return result;
};
})();
Object.defineProperty(exports, "__esModule", { value: true });
exports.NeXtTDNNModel = void 0;
const ort = __importStar(require("onnxruntime-web"));
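/**
 * Sketch of the config object consumed below, inferred from how the
 * constructor and initialize() read it (the compiled file carries no
 * annotations, so treat this typedef as a best-effort reconstruction):
 * @typedef {Object} ModelConfig
 * @property {string|ArrayBuffer|Uint8Array} [modelPath] URL/path to the .onnx file, or its raw bytes
 * @property {ArrayBuffer|Uint8Array} [modelData] Raw model bytes; used when modelPath is absent
 * @property {string[]} [executionProviders] ONNX Runtime execution providers; defaults to ['wasm']
 */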
class NeXtTDNNModel {
constructor(config) {
this.session = null;
this.config = config;
}
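    /**
     * Loads the ONNX model and creates the inference session.
     * Must resolve before infer() is called.
     */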
async initialize() {
try {
// Configure ONNX Runtime
ort.env.wasm.numThreads = navigator.hardwareConcurrency || 4;
ort.env.wasm.simd = true;
// Set execution providers
const executionProviders = this.config.executionProviders || ['wasm'];
            // Session options shared by both load paths below
            const sessionOptions = {
                executionProviders,
                graphOptimizationLevel: 'all'
            };
// Determine the model source (modelPath or modelData)
const modelPath = this.config.modelPath;
const modelData = this.config.modelData;
if (!modelPath && !modelData) {
throw new Error('Either modelPath or modelData must be provided');
}
// Create session with the appropriate method based on input type
if (modelPath && typeof modelPath === 'string') {
// Load from URL/path
this.session = await ort.InferenceSession.create(modelPath, sessionOptions);
}
else if (modelData || (modelPath && typeof modelPath !== 'string')) {
// Load from ArrayBuffer or Uint8Array
const data = modelData || modelPath;
this.session = await ort.InferenceSession.create(data, sessionOptions);
}
else {
throw new Error('Invalid model source');
}
}
catch (error) {
throw new Error(`Failed to load model: ${error}`);
}
}
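    /**
     * Runs the model on a flattened mel spectrogram (expected length
     * 80 * numFrames, matching the [1, 80, numFrames] input tensor) and
     * returns an L2-normalized speaker embedding.
     */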
async infer(melSpectrogram, numFrames) {
if (!this.session) {
throw new Error('Model not initialized. Call initialize() first.');
}
try {
// Get input metadata
const inputName = this.session.inputNames[0];
const outputName = this.session.outputNames[0];
// Reshape input for model [batch_size, mel_bins, time_frames]
const nMels = 80; // Fixed for our model
const inputTensor = new ort.Tensor('float32', melSpectrogram, [1, nMels, numFrames]);
// Run inference
const startTime = performance.now();
const results = await this.session.run({ [inputName]: inputTensor });
const inferenceTime = performance.now() - startTime;
// Extract embeddings
const output = results[outputName];
const rawOutput = output.data;
            // Sanity-check the output before any post-processing
            if (rawOutput.length === 0) {
                throw new Error('Model output is empty');
            }
            // Reject outputs containing NaN values
            if (rawOutput.some((v) => Number.isNaN(v))) {
                throw new Error('Model output contains NaN values');
            }
            // The model emits either [batch_size, hidden_dim] (already pooled)
            // or [batch_size, hidden_dim, time_frames], so branch on the rank
            let embeddings;
            if (output.dims.length === 2) {
                // Already [batch_size, hidden_dim]; copy the output directly
                embeddings = new Float32Array(rawOutput);
            }
            else if (output.dims.length === 3) {
                // [batch_size, hidden_dim, time_frames]: mean-pool over time
                const [, hiddenDim, timeFrames] = output.dims;
                embeddings = new Float32Array(hiddenDim);
                for (let h = 0; h < hiddenDim; h++) {
                    let sum = 0;
                    for (let t = 0; t < timeFrames; t++) {
                        // Row-major layout, batch=1: (h, t) -> h * timeFrames + t
                        sum += rawOutput[h * timeFrames + t];
                    }
                    embeddings[h] = sum / timeFrames;
                }
            }
else {
throw new Error(`Unexpected output dimensions: ${output.dims}`);
}
            // L2-normalize the embedding so cosine similarity reduces to a dot product
            let norm = 0;
            for (let i = 0; i < embeddings.length; i++) {
                norm += embeddings[i] * embeddings[i];
            }
            norm = Math.sqrt(norm);
            // Guard against division by zero for an all-zero (or NaN) embedding
            if (norm === 0 || Number.isNaN(norm)) {
                norm = 1;
            }
            for (let i = 0; i < embeddings.length; i++) {
                embeddings[i] /= norm;
            }
            return {
                embedding: embeddings,
                inferenceTime,
                timestamp: Date.now()
            };
}
catch (error) {
throw new Error(`Inference failed: ${error}`);
}
}
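    /**
     * Releases the ONNX Runtime session; initialize() must be called
     * again before any further inference.
     */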
async cleanup() {
if (this.session) {
await this.session.release();
this.session = null;
}
}
// Get model metadata
    getModelInfo() {
        if (!this.session)
            return null;
        // Note: ONNX Runtime Web doesn't expose tensor shapes directly,
        // so the shapes below reflect this model's expected I/O
        return {
            inputName: this.session.inputNames[0],
            outputName: this.session.outputNames[0],
            inputShape: [1, 80, -1], // -1 marks the dynamic time dimension
            outputShape: [1, 192]
        };
    }
}
exports.NeXtTDNNModel = NeXtTDNNModel;
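// ---------------------------------------------------------------------------
// Usage sketch (illustrative, not part of the published file). The model URL,
// the mel inputs, and the 0.6 threshold are hypothetical placeholders; the
// import assumes the package entry re-exports NeXtTDNNModel. Because infer()
// L2-normalizes its embeddings, cosine similarity between two utterances
// reduces to a plain dot product:
//
//   const { NeXtTDNNModel } = require('@jaehyun-ko/speaker-verification');
//
//   const model = new NeXtTDNNModel({ modelPath: '/models/next-tdnn.onnx' });
//   await model.initialize();
//
//   // melA/melB: Float32Array of length 80 * numFrames (hypothetical inputs)
//   const { embedding: a } = await model.infer(melA, numFramesA);
//   const { embedding: b } = await model.infer(melB, numFramesB);
//
//   let score = 0;
//   for (let i = 0; i < a.length; i++) score += a[i] * b[i];
//   const sameSpeaker = score > 0.6; // threshold must be tuned on real data
//
//   await model.cleanup();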