// @astermind/astermind-premium
// Version:
// Astermind Premium - Premium ML Toolkit
// 171 lines • 6.08 kB
// JavaScript
// robust-kernel-elm.ts — Robust Kernel ELM
// Outlier-resistant kernels with robust loss functions
import { KernelELM } from '@astermind/astermind-elm';
import { requireLicense } from '../core/license.js';
/**
* Robust Kernel ELM with outlier resistance
* Features:
* - Outlier-resistant kernels
* - Robust loss functions
* - Noise-tolerant learning
* - Outlier detection
*/
export class RobustKernelELM {
    /**
     * Build a license-gated KernelELM wrapper that filters statistical outliers.
     *
     * @param {object} options
     * @param {string[]} options.categories - Class names; string labels given to
     *   `train` are mapped to their index in this array.
     * @param {string} [options.kernelType='rbf'] - Forwarded to KernelELM as `kernel`.
     * @param {string} [options.robustLoss='huber'] - 'huber' | 'hinge' |
     *   'epsilon-insensitive'; any other value selects squared loss in `_robustLoss`.
     * @param {number} [options.outlierThreshold=2.0] - A sample whose largest
     *   per-feature |z-score| exceeds this value is treated as an outlier.
     * @param {number} [options.gamma=1.0] - Forwarded to KernelELM.
     * @param {number} [options.degree=2] - Forwarded to KernelELM.
     * @param {number} [options.coef0=0] - Forwarded to KernelELM.
     * @param {string} [options.activation='relu'] - Stored in `this.options` only;
     *   not forwarded to KernelELM by this class.
     * @param {number} [options.maxLen=100] - Forwarded to KernelELM.
     * @param {boolean} [options.useTokenizer=true] - Forwarded as `true` or `undefined`.
     */
    constructor(options) {
        // Row indices of the last training set that were excluded as outliers.
        this.outlierIndices = new Set();
        // Per-feature mean / population std of the last training set; these are
        // the reference distribution for prediction-time outlier checks.
        this.featureMeans = null;
        this.featureStds = null;
        this.trained = false;
        requireLicense(); // Premium feature - requires valid license
        this.categories = options.categories;
        this.options = {
            categories: options.categories,
            kernelType: options.kernelType ?? 'rbf',
            robustLoss: options.robustLoss ?? 'huber',
            outlierThreshold: options.outlierThreshold ?? 2.0,
            gamma: options.gamma ?? 1.0,
            degree: options.degree ?? 2,
            coef0: options.coef0 ?? 0,
            activation: options.activation ?? 'relu',
            maxLen: options.maxLen ?? 100,
            useTokenizer: options.useTokenizer ?? true,
        };
        this.kelm = new KernelELM({
            useTokenizer: this.options.useTokenizer ? true : undefined,
            categories: this.options.categories,
            maxLen: this.options.maxLen,
            kernel: this.options.kernelType,
            gamma: this.options.gamma,
            degree: this.options.degree,
            coef0: this.options.coef0,
        });
    }

    /**
     * Fit the underlying KernelELM after removing statistical outliers.
     *
     * Rows flagged by `_detectOutliers` are dropped from the fit (not
     * down-weighted). If every row is flagged, the full data set is used so
     * the model is never fit on nothing.
     *
     * @param {number[][]} X - Training feature vectors (assumed numeric).
     * @param {Array<string|number>} y - One label per row: a category index or
     *   a name from `options.categories`.
     * @throws {Error} If `X` and `y` differ in length, or a string label is not
     *   one of `options.categories`.
     */
    train(X, y) {
        if (X.length !== y.length) {
            throw new Error(`X and y must have the same length (got ${X.length} and ${y.length})`);
        }
        const { categories } = this.options;
        const labelIndices = y.map((label) => {
            if (typeof label === 'number') {
                return label;
            }
            const index = categories.indexOf(label);
            // indexOf yields -1 for unknown labels; silently training on -1 is a bug.
            if (index === -1) {
                throw new Error(`Unknown label "${label}"; expected one of: ${categories.join(', ')}`);
            }
            return index;
        });
        this._detectOutliers(X);
        const keptX = [];
        const keptY = [];
        X.forEach((row, i) => {
            if (!this.outlierIndices.has(i)) {
                keptX.push(row);
                keptY.push(labelIndices[i]);
            }
        });
        const useAll = keptX.length === 0;
        this.kelm.setCategories?.(this.options.categories);
        // NOTE(review): with optional chaining, a KernelELM lacking `trainFromData`
        // trains nothing while `trained` still becomes true — confirm the method
        // exists in the installed astermind-elm version.
        this.kelm.trainFromData?.(useAll ? X : keptX, useAll ? labelIndices : keptY);
        this.trained = true;
    }

    /**
     * Record per-feature training statistics and flag outlier rows.
     *
     * Uses the mean and population standard deviation of each feature. A row
     * is an outlier when its largest |z-score| over features with non-zero
     * std exceeds `options.outlierThreshold`. The mean and std are themselves
     * influenced by extreme values, so very large outliers inflate the std.
     * Missing / falsy feature values are read as 0.
     *
     * @param {number[][]} X - Feature vectors; the dimension is taken from `X[0]`.
     */
    _detectOutliers(X) {
        this.outlierIndices.clear();
        if (X.length === 0) {
            this.featureMeans = null;
            this.featureStds = null;
            return;
        }
        const n = X.length;
        const dim = X[0].length;
        const means = new Array(dim).fill(0);
        const stds = new Array(dim).fill(0);
        for (const x of X) {
            for (let j = 0; j < dim; j++) {
                means[j] += x[j] || 0;
            }
        }
        for (let j = 0; j < dim; j++) {
            means[j] /= n;
        }
        for (const x of X) {
            for (let j = 0; j < dim; j++) {
                stds[j] += ((x[j] || 0) - means[j]) ** 2;
            }
        }
        for (let j = 0; j < dim; j++) {
            stds[j] = Math.sqrt(stds[j] / n);
        }
        this.featureMeans = means;
        this.featureStds = stds;
        X.forEach((x, i) => {
            if (this._maxZScore(x) > this.options.outlierThreshold) {
                this.outlierIndices.add(i);
            }
        });
    }

    /**
     * Largest |z-score| of `x` against the stored training statistics.
     * Features with zero training std are skipped; falsy values read as 0.
     *
     * @param {ArrayLike<number>} x - Feature vector.
     * @returns {number} Max absolute z-score (0 if no feature has spread).
     */
    _maxZScore(x) {
        const means = this.featureMeans;
        const stds = this.featureStds;
        let maxZScore = 0;
        for (let j = 0; j < means.length; j++) {
            if (stds[j] > 0) {
                maxZScore = Math.max(maxZScore, Math.abs((x[j] || 0) - means[j]) / stds[j]);
            }
        }
        return maxZScore;
    }

    /**
     * Evaluate the configured robust loss on a single error value.
     *
     * - 'huber': quadratic for |error| <= 1, linear beyond (delta = 1).
     * - 'hinge': max(0, 1 - error); `error` is interpreted as a margin here.
     * - 'epsilon-insensitive': max(0, |error| - 0.1).
     * - anything else: squared error.
     *
     * Not called anywhere in this class; fitting is done by outlier filtering.
     *
     * @param {number} error
     * @returns {number}
     */
    _robustLoss(error) {
        const magnitude = Math.abs(error);
        switch (this.options.robustLoss) {
            case 'huber': {
                const delta = 1.0;
                return magnitude <= delta
                    ? 0.5 * error * error
                    : delta * (magnitude - 0.5 * delta);
            }
            case 'hinge':
                return Math.max(0, 1 - error);
            case 'epsilon-insensitive': {
                const epsilon = 0.1;
                return Math.max(0, magnitude - epsilon);
            }
            default:
                return error * error; // Default: squared loss
        }
    }

    /**
     * Predict the top-K labels for one vector or a batch of vectors.
     *
     * Results for all inputs are concatenated into one flat array, in input
     * order, up to `topK` entries per input. Each entry carries `isOutlier`
     * (judged against the training distribution) and `robustness` (0.5 for
     * outliers, otherwise 1.0).
     *
     * @param {number[]|number[][]} X - A single feature vector or a batch.
     * @param {number} [topK=3] - Maximum predictions kept per input.
     * @returns {{label: string, prob: number, isOutlier: boolean, robustness: number}[]}
     * @throws {Error} If called before `train`.
     */
    predict(X, topK = 3) {
        if (!this.trained) {
            throw new Error('Model must be trained before prediction');
        }
        if (X.length === 0) {
            return [];
        }
        // Rows may be plain arrays or typed arrays (e.g. Float32Array).
        const isBatch = Array.isArray(X[0]) || ArrayBuffer.isView(X[0]);
        const rows = isBatch ? X : [X];
        const results = [];
        for (const x of rows) {
            const isOutlier = this._isOutlier(x);
            const robustness = isOutlier ? 0.5 : 1.0; // Lower robustness for outliers
            let preds = this.kelm.predictFromVector?.([x], topK) || [];
            // Batch-style predictors return one result list per input row; we
            // passed a single row, so unwrap it. Flat lists pass through as-is.
            if (Array.isArray(preds[0])) {
                preds = preds[0];
            }
            for (const pred of preds.slice(0, topK)) {
                results.push({
                    label: pred.label ?? this.options.categories[pred.index ?? 0],
                    prob: pred.prob || 0,
                    isOutlier,
                    robustness,
                });
            }
        }
        return results;
    }

    /**
     * Whether `x` is an outlier relative to the last training set.
     *
     * Uses the same per-feature z-score rule as `_detectOutliers`. Returns
     * false when no training statistics exist, because there is no reference
     * distribution to compare against.
     *
     * @param {ArrayLike<number>} x - Feature vector.
     * @returns {boolean}
     */
    _isOutlier(x) {
        if (!this.featureMeans || !this.featureStds) {
            return false;
        }
        return this._maxZScore(x) > this.options.outlierThreshold;
    }
}
//# sourceMappingURL=robust-kernel-elm.js.map