@astermind/astermind-premium
Version:
Astermind Premium - Premium ML Toolkit
141 lines • 5.37 kB
JavaScript
// ensemble-kernel-elm.ts — Ensemble Kernel ELM
// Multiple KELM models with different kernels, voting/weighted combination
import { KernelELM } from '@astermind/astermind-elm';
import { requireLicense } from '../core/license.js';
/**
 * Ensemble Kernel ELM
 *
 * Trains one KernelELM per kernel specification and combines their top-K
 * predictions by majority, weighted, or plain-average voting.
 *
 * Features:
 * - Multiple KELM models with different kernels
 * - Voting/weighted combination
 * - Diversity promotion (different kernel families see the data differently)
 * - Robust predictions
 */
export class EnsembleKernelELM {
    /**
     * @param {object} options
     * @param {string[]} options.categories - Class labels; a label's position is its index.
     * @param {Array<{type: string, gamma?: number, degree?: number, coef0?: number, weight?: number}>} [options.kernels]
     *   One KernelELM is built per entry. Defaults to rbf (gamma 1), polynomial
     *   (degree 2, coef0 0) and linear. A missing `weight` defaults to 1.0.
     * @param {'majority'|'weighted'|'average'} [options.votingType='weighted'] - How model outputs are combined.
     * @param {string} [options.activation='relu'] - Stored on `this.options`; not forwarded to the models.
     * @param {number} [options.maxLen=100] - Forwarded to every KernelELM.
     * @param {boolean} [options.useTokenizer=true] - When true, `useTokenizer: true` is forwarded; otherwise it is left undefined.
     * @throws {Error} If `votingType` is not one of the supported strategies.
     */
    constructor(options) {
        this.models = [];
        this.trained = false;
        requireLicense(); // Premium feature - requires valid license
        this.categories = options.categories;
        // Caller-provided kernels win; otherwise fall back to a diverse default trio.
        const kernelSpecs = options.kernels || [
            { type: 'rbf', gamma: 1.0, weight: 1.0 },
            { type: 'polynomial', degree: 2, coef0: 0, weight: 1.0 },
            { type: 'linear', weight: 1.0 },
        ];
        this.options = {
            categories: options.categories,
            kernels: kernelSpecs.map(k => ({ ...k, weight: k.weight ?? 1.0 })),
            votingType: options.votingType ?? 'weighted',
            activation: options.activation ?? 'relu',
            maxLen: options.maxLen ?? 100,
            useTokenizer: options.useTokenizer ?? true,
        };
        // An unknown strategy would otherwise silently produce all-zero probabilities.
        const supportedVoting = ['majority', 'weighted', 'average'];
        if (!supportedVoting.includes(this.options.votingType)) {
            throw new Error(`Unsupported votingType "${this.options.votingType}"; expected one of ${supportedVoting.join(', ')}`);
        }
        // One model per kernel; models[i] pairs with options.kernels[i] (relied on for weights).
        for (const kernel of this.options.kernels) {
            const kelm = new KernelELM({
                useTokenizer: this.options.useTokenizer ? true : undefined,
                categories: this.options.categories,
                maxLen: this.options.maxLen,
                kernel: kernel.type,
                gamma: kernel.gamma,
                degree: kernel.degree,
                coef0: kernel.coef0,
            });
            this.models.push(kelm);
        }
    }
    /**
     * Train every model in the ensemble on the same data.
     *
     * @param {Array} X - Training samples, one per label in `y`.
     * @param {Array<string|number>} y - Category names or category indices.
     * @throws {Error} If `X` and `y` differ in length, or a label is not a
     *   configured category / valid category index.
     */
    train(X, y) {
        if (X.length !== y.length) {
            throw new Error(`X and y must have the same length (got ${X.length} and ${y.length})`);
        }
        const { categories } = this.options;
        // First occurrence wins, matching Array.prototype.indexOf semantics.
        const indexByCategory = new Map();
        categories.forEach((category, i) => {
            if (!indexByCategory.has(category)) {
                indexByCategory.set(category, i);
            }
        });
        const labelIndices = y.map((label) => {
            const index = typeof label === 'number' ? label : indexByCategory.get(label);
            if (index === undefined || !Number.isInteger(index) || index < 0 || index >= categories.length) {
                throw new Error(`Unknown label ${JSON.stringify(label)}; expected a configured category or an index in [0, ${categories.length})`);
            }
            return index;
        });
        // NOTE(review): optional calls silently skip models lacking these methods.
        for (const model of this.models) {
            model.setCategories?.(this.options.categories);
            model.trainFromData?.(X, labelIndices);
        }
        this.trained = true;
    }
    /**
     * Predict with ensemble voting.
     *
     * @param {Array} X - A single feature vector, or an array of feature vectors.
     * @param {number} [topK=3] - Predictions requested from each model and kept per sample.
     * @returns {Array<{label: string, prob: number, votes: number, confidence: number}>}
     *   Combined top-K results for every sample, concatenated into one flat array.
     * @throws {Error} If called before `train`.
     */
    predict(X, topK = 3) {
        if (!this.trained) {
            throw new Error('Model must be trained before prediction');
        }
        // An empty batch has nothing to score; don't hand an empty vector to the models.
        if (X.length === 0) {
            return [];
        }
        const batch = Array.isArray(X[0]) ? X : [X];
        const allResults = [];
        for (const x of batch) {
            // One normalized prediction list per model, in model order.
            const modelPredictions = this.models.map((model) => {
                const preds = model.predictFromVector?.([x], topK) || [];
                return preds.map((p) => ({
                    label: p.label || this.options.categories[p.index || 0],
                    prob: p.prob || 0,
                    index: p.index || 0,
                }));
            });
            allResults.push(...this._combinePredictions(modelPredictions, topK));
        }
        return allResults;
    }
    /**
     * Combine per-model prediction lists into one ranked list.
     *
     * - majority: each model casts one vote for its highest-probability label;
     *   prob = votes / number of models.
     * - weighted: prob = Σ(weight_i · prob_i) / Σ weight_i over ALL models.
     * - average:  prob = Σ prob_i / number of models.
     * A label absent from a model's top-K contributes 0 for that model, so
     * labels returned by few models are not inflated.
     *
     * @param {Array<Array<{label: string, prob: number}>>} modelPredictions - modelPredictions[i] comes from model i.
     * @param {number} topK - Maximum number of results returned.
     * @returns {Array<{label: string, prob: number, votes: number, confidence: number}>} Sorted by prob, descending.
     */
    _combinePredictions(modelPredictions, topK) {
        const { votingType, kernels } = this.options;
        const modelCount = modelPredictions.length;
        const labelScores = new Map();
        let totalWeight = 0;
        modelPredictions.forEach((preds, modelIdx) => {
            const weight = kernels[modelIdx].weight;
            totalWeight += weight;
            for (const pred of preds) {
                if (!labelScores.has(pred.label)) {
                    labelScores.set(pred.label, { prob: 0, votes: 0 });
                }
                const score = labelScores.get(pred.label);
                if (votingType === 'weighted') {
                    score.prob += pred.prob * weight;
                    score.votes += 1;
                }
                else if (votingType === 'average') {
                    score.prob += pred.prob;
                    score.votes += 1;
                }
            }
            // Majority: only the model's single best label receives its vote.
            if (votingType === 'majority' && preds.length > 0) {
                const best = preds.reduce((a, b) => (b.prob > a.prob ? b : a));
                labelScores.get(best.label).votes += 1;
            }
        });
        // Normalize against the whole ensemble and sort.
        const results = [];
        for (const [label, score] of labelScores) {
            let finalProb;
            if (votingType === 'majority') {
                finalProb = score.votes / modelCount;
            }
            else if (votingType === 'weighted') {
                finalProb = totalWeight > 0 ? score.prob / totalWeight : 0;
            }
            else {
                finalProb = score.prob / modelCount;
            }
            results.push({
                label,
                prob: finalProb,
                votes: score.votes,
                confidence: finalProb * (score.votes / modelCount),
            });
        }
        // Sort by probability and return top K
        results.sort((a, b) => b.prob - a.prob);
        return results.slice(0, topK);
    }
}
//# sourceMappingURL=ensemble-kernel-elm.js.map