UNPKG

@astermind/astermind-premium

Version:

Astermind Premium - Premium ML Toolkit

145 lines 5.72 kB
// variational-elm.ts — Variational ELM with uncertainty estimation
// Probabilistic ELM with Bayesian inference and confidence intervals
import { ELM } from '@astermind/astermind-elm';
import { requireLicense } from '../core/license.js';

/** Uncertainty reported when it cannot be estimated (no samples, or forward pass not replayable). */
const DEFAULT_UNCERTAINTY = 0.5;

/**
 * Scalar activations used to replay the ELM hidden layer during uncertainty estimation.
 * Keys are lower-cased `activation` option names. An activation not listed here disables
 * the replay, and DEFAULT_UNCERTAINTY is reported instead.
 * NOTE(review): assumed to match how @astermind/astermind-elm defines these activations — confirm.
 */
const ACTIVATIONS = new Map([
  ['relu', (z) => Math.max(0, z)],
  ['sigmoid', (z) => 1 / (1 + Math.exp(-z))],
  ['tanh', (z) => Math.tanh(z)],
  ['linear', (z) => z],
]);

/**
 * Variational ELM with uncertainty estimation
 * Features:
 * - Probabilistic predictions with uncertainty
 * - Bayesian inference
 * - Confidence intervals
 * - Robust predictions with uncertainty quantification
 */
export class VariationalELM {
  /**
   * @param {object} options
   * @param {string[]} options.categories - Class labels, in index order.
   * @param {number} [options.hiddenUnits=256] - Hidden-layer width of the base ELM.
   * @param {number} [options.priorVariance=1.0] - Variance of the Gaussian noise added to each
   *   hidden weight when drawing posterior samples.
   * @param {number} [options.posteriorSamples=10] - Number of perturbed weight matrices drawn after training.
   * @param {string} [options.activation='relu'] - Hidden activation, passed to the base ELM.
   * @param {number} [options.maxLen=100] - Passed through to the base ELM.
   * @param {boolean} [options.useTokenizer=true] - When true, the base ELM is created with `useTokenizer: true`.
   */
  constructor(options) {
    this.weightSamples = []; // Sampled weight matrices
    this.trained = false;
    requireLicense(); // Premium feature - requires valid license
    this.categories = options.categories;
    this.options = {
      categories: options.categories,
      hiddenUnits: options.hiddenUnits ?? 256,
      priorVariance: options.priorVariance ?? 1.0,
      posteriorSamples: options.posteriorSamples ?? 10,
      activation: options.activation ?? 'relu',
      maxLen: options.maxLen ?? 100,
      useTokenizer: options.useTokenizer ?? true,
    };
    this.elm = new ELM({
      useTokenizer: this.options.useTokenizer ? true : undefined,
      hiddenUnits: this.options.hiddenUnits,
      categories: this.options.categories,
      maxLen: this.options.maxLen,
      activation: this.options.activation,
    });
  }

  /**
   * Train the base ELM, then draw posterior weight samples for uncertainty estimation.
   * @param {Array} X - Training inputs, forwarded unchanged to the base ELM's `trainFromData`.
   * @param {Array<string|number>} y - Labels: category names, or integer indices into `categories`.
   * @throws {Error} If X and y differ in length, or a label is not a known category / valid index.
   */
  train(X, y) {
    if (Array.isArray(X) && X.length !== y.length) {
      throw new Error(`X and y must have the same length (got ${X.length} and ${y.length})`);
    }
    const { categories } = this.options;
    const labelIndices = y.map((label) => {
      const index = typeof label === 'number' ? label : categories.indexOf(label);
      // indexOf yields -1 for unknown names; reject that (and out-of-range numbers) here
      // instead of handing an invalid class index to the ELM.
      if (!Number.isInteger(index) || index < 0 || index >= categories.length) {
        throw new Error(`Unknown label: ${String(label)}`);
      }
      return index;
    });
    // Train base ELM
    this.elm.setCategories?.(categories);
    this.elm.trainFromData?.(X, labelIndices);
    // Sample weights for uncertainty estimation
    this._sampleWeights();
    this.trained = true;
  }

  /**
   * Predict with uncertainty estimation.
   * @param {Array} X - One input row, or an array of rows (detected by `Array.isArray(X[0])`).
   * @param {number} [topK=3] - Maximum number of predictions kept per row.
   * @param {boolean} [includeUncertainty=true] - When false, DEFAULT_UNCERTAINTY is used instead of sampling.
   * @returns {Array<{label: *, prob: number, confidence: number, uncertainty: number,
   *   confidenceInterval: [number, number]}>} Flat list with up to `topK` entries per row.
   * @throws {Error} If called before `train`.
   */
  predict(X, topK = 3, includeUncertainty = true) {
    if (!this.trained) {
      throw new Error('Model must be trained before prediction');
    }
    if (Array.isArray(X) && X.length === 0) {
      return [];
    }
    const rows = Array.isArray(X[0]) ? X : [X];
    const results = [];
    for (const x of rows) {
      const basePreds = this.elm.predictFromVector?.([x], topK) || [];
      const uncertainty = includeUncertainty ? this._estimateUncertainty(x) : DEFAULT_UNCERTAINTY;
      const confidence = Math.max(0, Math.min(1, 1 - uncertainty));
      for (const pred of basePreds.slice(0, topK)) {
        const prob = Number.isFinite(pred.prob) ? pred.prob : 0;
        // Normal-approximation 95% interval; max(0, …) keeps sqrt real if prob strays outside [0, 1].
        const stdDev = Math.sqrt(Math.max(0, uncertainty * prob * (1 - prob)));
        results.push({
          label: pred.label ?? this.options.categories[pred.index ?? 0],
          prob,
          confidence,
          uncertainty,
          confidenceInterval: [
            Math.max(0, prob - 1.96 * stdDev),
            Math.min(1, prob + 1.96 * stdDev),
          ],
        });
      }
    }
    return results;
  }

  /**
   * Estimate uncertainty for one input row. Each sampled hidden-weight matrix is used to
   * replay the forward pass. The result is the variance, across samples, of the probability
   * given to the class the unperturbed model ranks highest.
   * @param {*} x - A single input row.
   * @returns {number} Population variance across samples, capped at 1. Returns DEFAULT_UNCERTAINTY
   *   when there are no samples or the forward pass cannot be replayed for this input.
   */
  _estimateUncertainty(x) {
    if (this.weightSamples.length === 0) {
      return DEFAULT_UNCERTAINTY;
    }
    const baseProbs = this._forwardProbs(x, this.elm.model?.W);
    if (baseProbs === null) {
      // Cannot replay the network (unexpected model layout, non-numeric input, or unsupported
      // activation): report "unknown" rather than a falsely certain 0.
      return DEFAULT_UNCERTAINTY;
    }
    const topClass = baseProbs.indexOf(Math.max(...baseProbs));
    const predictions = this.weightSamples.map((weights) => this._predictWithWeights(x, weights, topClass));
    const mean = predictions.reduce((a, b) => a + b, 0) / predictions.length;
    const variance = predictions.reduce((sum, p) => sum + (p - mean) ** 2, 0) / predictions.length;
    // Probabilities lie in [0, 1], so this cap is only a safety bound.
    return Math.min(1, variance);
  }

  /**
   * Probability of one class under the network with hidden weights `weights`.
   * @param {*} x - A single input row.
   * @param {number[][]} weights - Hidden weight matrix, same shape as the model's W.
   * @param {number} [classIndex] - Class to report. When missing or out of range, the maximum
   *   class probability is returned.
   * @returns {number} A probability, or the neutral 0.5 if the forward pass cannot be replayed.
   */
  _predictWithWeights(x, weights, classIndex) {
    const probs = this._forwardProbs(x, weights);
    if (probs === null) {
      return 0.5;
    }
    if (Number.isInteger(classIndex) && classIndex >= 0 && classIndex < probs.length) {
      return probs[classIndex];
    }
    return Math.max(...probs);
  }

  /**
   * Replay the ELM forward pass softmax(beta^T · act(W · x + b)) using hidden weights `W`.
   * NOTE(review): assumes `this.elm.model` holds { W: hidden×inputDim, b: hidden×1, beta:
   * hidden×classes }, matching @astermind/astermind-elm's model layout — confirm. Any shape,
   * type or activation mismatch returns null so callers can fall back.
   * @param {*} x - A single input row; only numeric arrays are supported.
   * @param {number[][]} W - Hidden weight matrix to use instead of the model's own.
   * @returns {number[]|null} Class probabilities, or null if the pass cannot be computed.
   */
  _forwardProbs(x, W) {
    const activate = ACTIVATIONS.get(String(this.options.activation).toLowerCase());
    const model = this.elm.model;
    const b = model?.b;
    const beta = model?.beta;
    if (!activate || !Array.isArray(x) || !Array.isArray(W) || W.length === 0) return null;
    if (!Array.isArray(b) || !Array.isArray(beta)) return null;
    if (b.length !== W.length || beta.length !== W.length) return null;
    const numClasses = Array.isArray(beta[0]) ? beta[0].length : 0;
    if (numClasses === 0) return null;

    // Hidden layer: act(W·x + b), with b stored either as [[b_i]] rows or as plain numbers.
    const hidden = new Array(W.length);
    for (let i = 0; i < W.length; i++) {
      const row = W[i];
      if (!Array.isArray(row) || row.length !== x.length) return null;
      let z = Array.isArray(b[i]) ? b[i][0] : b[i];
      for (let j = 0; j < row.length; j++) z += row[j] * x[j];
      hidden[i] = activate(z);
    }

    // Output layer: logits = hidden · beta.
    const logits = new Array(numClasses).fill(0);
    for (let i = 0; i < hidden.length; i++) {
      const betaRow = beta[i];
      if (!Array.isArray(betaRow) || betaRow.length !== numClasses) return null;
      for (let k = 0; k < numClasses; k++) logits[k] += hidden[i] * betaRow[k];
    }

    // Numerically stable softmax (shift by the max logit before exponentiating).
    const maxLogit = Math.max(...logits);
    const exps = logits.map((v) => Math.exp(v - maxLogit));
    const total = exps.reduce((acc, e) => acc + e, 0);
    const probs = exps.map((e) => e / total);
    return probs.every(Number.isFinite) ? probs : null;
  }

  /**
   * Draw `posteriorSamples` copies of the trained hidden weights W, each perturbed by
   * independent Gaussian noise with variance `priorVariance`. Results go to `this.weightSamples`.
   */
  _sampleWeights() {
    // Clear first so a retrain that yields no usable model cannot leave stale samples behind.
    this.weightSamples = [];
    const baseWeights = this.elm.model?.W;
    if (!baseWeights) return;
    const { posteriorSamples, priorVariance } = this.options;
    for (let s = 0; s < posteriorSamples; s++) {
      // Sample from posterior (Gaussian around base weight)
      this.weightSamples.push(
        baseWeights.map((row) => row.map((w) => w + this._gaussianRandom(0, priorVariance))),
      );
    }
  }

  /**
   * Draw one sample from N(mean, variance) via the Box–Muller transform.
   * @param {number} mean
   * @param {number} variance - Variance (not standard deviation).
   * @returns {number}
   */
  _gaussianRandom(mean, variance) {
    // Math.random() lies in [0, 1); 1 - Math.random() lies in (0, 1], so log(u1) stays finite.
    const u1 = 1 - Math.random();
    const u2 = Math.random();
    const z0 = Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
    return mean + z0 * Math.sqrt(variance);
  }
}
//# sourceMappingURL=variational-elm.js.map