// @astermind/astermind-premium
// Version:
// Astermind Premium - Premium ML Toolkit
// 145 lines • 5.72 kB
// JavaScript
// variational-elm.ts — Variational ELM with uncertainty estimation
// Probabilistic ELM with Bayesian inference and confidence intervals
import { ELM } from '@astermind/astermind-elm';
import { requireLicense } from '../core/license.js';
/**
* Variational ELM with uncertainty estimation
* Features:
* - Probabilistic predictions with uncertainty
* - Bayesian inference
* - Confidence intervals
* - Robust predictions with uncertainty quantification
*/
export class VariationalELM {
    /**
     * Build the wrapper and its underlying ELM.
     *
     * @param {object} options
     * @param {string[]} options.categories - Class labels; string labels passed to
     *   `train` are mapped to their index in this list.
     * @param {number} [options.hiddenUnits=256]
     * @param {number} [options.priorVariance=1.0] - Variance of the Gaussian noise
     *   added to the base weights when sampling.
     * @param {number} [options.posteriorSamples=10] - Number of noisy weight copies.
     * @param {string} [options.activation='relu']
     * @param {number} [options.maxLen=100]
     * @param {boolean} [options.useTokenizer=true]
     */
    constructor(options) {
        this.weightSamples = []; // Sampled weight matrices, one per posterior sample
        this.trained = false;
        requireLicense(); // Premium feature - requires valid license
        this.categories = options.categories;
        this.options = {
            categories: options.categories,
            hiddenUnits: options.hiddenUnits ?? 256,
            priorVariance: options.priorVariance ?? 1.0,
            posteriorSamples: options.posteriorSamples ?? 10,
            activation: options.activation ?? 'relu',
            maxLen: options.maxLen ?? 100,
            useTokenizer: options.useTokenizer ?? true,
        };
        this.elm = new ELM({
            // Only forward the flag when enabled; otherwise leave it unset.
            useTokenizer: this.options.useTokenizer ? true : undefined,
            hiddenUnits: this.options.hiddenUnits,
            categories: this.options.categories,
            maxLen: this.options.maxLen,
            activation: this.options.activation,
        });
    }

    /**
     * Train the base ELM and draw the weight samples used for uncertainty.
     *
     * @param {Array} X - Training inputs, forwarded unchanged to the ELM.
     * @param {Array<string|number>} y - Labels; numbers are used as class indices
     *   as-is, strings are looked up in `options.categories`.
     * @throws {Error} If a string label is not one of the configured categories.
     */
    train(X, y) {
        const { categories } = this.options;
        const labelIndices = y.map((label) => {
            if (typeof label === 'number') {
                return label;
            }
            const index = categories.indexOf(label);
            if (index === -1) {
                // Previously -1 was passed through silently as a class index.
                throw new Error(`Unknown label "${label}": not present in categories`);
            }
            return index;
        });
        // Train base ELM
        this.elm.setCategories?.(categories);
        this.elm.trainFromData?.(X, labelIndices);
        // Sample weights for uncertainty estimation
        this._sampleWeights();
        this.trained = true;
    }

    /**
     * Predict with uncertainty estimation.
     *
     * @param {Array} X - A single input vector or an array of input vectors
     *   (detected by whether `X[0]` is itself an array).
     * @param {number} [topK=3] - Maximum number of predictions kept per input.
     * @param {boolean} [includeUncertainty=true] - When false, a fixed
     *   uncertainty of 0.5 is reported instead of an estimate.
     * @returns {Array<{label: *, prob: number, confidence: number,
     *   uncertainty: number, confidenceInterval: number[]}>} Flat list of
     *   results for all inputs, in input order.
     * @throws {Error} If called before `train`.
     */
    predict(X, topK = 3, includeUncertainty = true) {
        if (!this.trained) {
            throw new Error('Model must be trained before prediction');
        }
        const XArray = Array.isArray(X[0]) ? X : [X];
        const allResults = [];
        for (const x of XArray) {
            // Get base prediction
            const rawPreds = this.elm.predictFromVector?.([x], topK) || [];
            // NOTE(review): predictFromVector is called with a one-row batch. If it
            // returns one result list per row (a nested array), unwrap that single
            // row; a flat result list is used as-is. Confirm against the ELM API.
            const basePreds = Array.isArray(rawPreds[0]) ? rawPreds[0] : rawPreds;
            // Estimate uncertainty (same value for every prediction of this input)
            const uncertainty = includeUncertainty ? this._estimateUncertainty(x) : 0.5;
            const confidence = Math.max(0, Math.min(1, 1 - uncertainty));
            for (const pred of basePreds.slice(0, topK)) {
                const prob = pred.prob || 0;
                // 95% normal-approximation interval around prob, clamped to [0, 1]
                const stdDev = Math.sqrt(uncertainty * prob * (1 - prob));
                const confidenceInterval = [
                    Math.max(0, prob - 1.96 * stdDev),
                    Math.min(1, prob + 1.96 * stdDev),
                ];
                allResults.push({
                    label: pred.label || this.options.categories[pred.index || 0],
                    prob,
                    confidence,
                    uncertainty,
                    confidenceInterval,
                });
            }
        }
        return allResults;
    }

    /**
     * Estimate uncertainty as the variance of predictions across weight samples.
     *
     * NOTE: `_predictWithWeights` is currently a placeholder returning a constant,
     * so whenever weight samples exist the variance (and thus the uncertainty)
     * is 0.
     *
     * @param {Array} x - A single input vector.
     * @returns {number} Uncertainty capped at 1; 0.5 when no weight samples exist.
     */
    _estimateUncertainty(x) {
        if (this.weightSamples.length === 0) {
            return 0.5; // Default uncertainty
        }
        // Get predictions from multiple weight samples
        const predictions = this.weightSamples.map((weights) => this._predictWithWeights(x, weights));
        // Compute (population) variance as uncertainty measure
        const mean = predictions.reduce((a, b) => a + b, 0) / predictions.length;
        const variance = predictions.reduce((sum, p) => sum + Math.pow(p - mean, 2), 0) / predictions.length;
        // Cap at 1 (variance is never negative)
        return Math.min(1, variance);
    }

    /**
     * Predict with a specific weight matrix (simplified placeholder).
     *
     * Ignores both arguments and always returns 0.5; a real implementation
     * would run the ELM forward pass with the sampled weights.
     *
     * @param {Array} x - Input vector (unused).
     * @param {number[][]} weights - Sampled weight matrix (unused).
     * @returns {number} Always 0.5.
     */
    _predictWithWeights(x, weights) {
        return 0.5;
    }

    /**
     * Sample weight matrices for uncertainty estimation by adding zero-mean
     * Gaussian noise (variance `options.priorVariance`) to `elm.model.W`.
     *
     * Existing samples are always discarded first, so a retrain that leaves
     * no model weights does not keep stale samples from an earlier run.
     */
    _sampleWeights() {
        this.weightSamples = [];
        const baseWeights = this.elm.model?.W;
        if (!baseWeights) {
            return;
        }
        const { posteriorSamples, priorVariance } = this.options;
        for (let s = 0; s < posteriorSamples; s++) {
            // Sample from posterior (Gaussian around each base weight)
            const sampled = baseWeights.map((row) => row.map((w) => w + this._gaussianRandom(0, priorVariance)));
            this.weightSamples.push(sampled);
        }
    }

    /**
     * Draw one sample from a normal distribution using the Box-Muller transform.
     *
     * @param {number} mean
     * @param {number} variance
     * @returns {number} A finite normally distributed sample.
     */
    _gaussianRandom(mean, variance) {
        // Math.random() is in [0, 1); flip it to (0, 1] so Math.log never sees 0,
        // which would otherwise produce an infinite sample.
        const u1 = 1 - Math.random();
        const u2 = Math.random();
        const z0 = Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
        return mean + z0 * Math.sqrt(variance);
    }
}
//# sourceMappingURL=variational-elm.js.map