@astermind/astermind-premium
Version:
Astermind Premium - Premium ML Toolkit
261 lines • 9.61 kB
JavaScript
// adaptive-online-elm.ts — Adaptive Online ELM with dynamic hidden unit adjustment
// Adjusts hidden units dynamically based on data complexity
// OnlineELM is imported directly from @astermind/astermind-elm, which this module consumes as an ES module.
import { OnlineELM } from '@astermind/astermind-elm';
import { requireLicense } from '../core/license.js';
/**
 * Adaptive Online ELM that dynamically adjusts its number of hidden units.
 *
 * Features:
 * - Grows hidden units when the recent error rate is above `growthThreshold`
 * - Shrinks hidden units when both the recent and the average error rate are
 *   below `shrinkThreshold`
 * - Refits a resized model from a bounded replay buffer of recently seen
 *   samples, so a resize does not discard what the model has learned
 *
 * Options (`categories` is required, the rest fall back to the defaults shown):
 * - categories: label names; the output dimension is `categories.length`
 * - initialHiddenUnits (128), minHiddenUnits (32), maxHiddenUnits (1024)
 * - growthThreshold (0.3), shrinkThreshold (0.1): misclassification-rate thresholds
 * - growthFactor (1.5), shrinkFactor (0.8): multiplicative resize factors
 * - activation ('relu'): passed through to OnlineELM
 * - bufferSize (1000): max samples kept for refitting after a resize; 0 disables
 *   replay, in which case a resize leaves the model untrained until the next fit()
 * - errorWindow (20): number of recent update() predictions averaged into the
 *   online error rate
 * - maxLen (100), useTokenizer (true): stored in `options`, not used by this class
 */
export class AdaptiveOnlineELM {
    /**
     * @param {object} options - configuration object (see class comment).
     * Calls `requireLicense()` before reading options; presumably it throws when
     * no valid license is present — NOTE(review): confirm in core/license.js.
     */
    constructor(options) {
        this.elm = null;
        this.trained = false;
        this.errorHistory = [];
        this.performanceHistory = [];
        // Replay buffer (input rows + one-hot targets) used to refit after a resize.
        this.bufferX = [];
        this.bufferY = [];
        // 0/1 miss indicators for the most recent update() samples.
        this.recentOutcomes = [];
        // Errors recorded since the last resize; adaptation waits for fresh evidence.
        this.errorsSinceAdapt = 0;
        requireLicense(); // Premium feature - requires valid license
        this.categories = options.categories;
        this.options = {
            categories: options.categories,
            initialHiddenUnits: options.initialHiddenUnits ?? 128,
            minHiddenUnits: options.minHiddenUnits ?? 32,
            maxHiddenUnits: options.maxHiddenUnits ?? 1024,
            growthThreshold: options.growthThreshold ?? 0.3,
            shrinkThreshold: options.shrinkThreshold ?? 0.1,
            growthFactor: options.growthFactor ?? 1.5,
            shrinkFactor: options.shrinkFactor ?? 0.8,
            activation: options.activation ?? 'relu',
            maxLen: options.maxLen ?? 100,
            useTokenizer: options.useTokenizer ?? true,
            bufferSize: options.bufferSize ?? 1000,
            errorWindow: options.errorWindow ?? 20,
        };
        this.currentHiddenUnits = this.options.initialHiddenUnits;
        // No input dimension is known yet, so this is a no-op; fit() creates the ELM.
        this._initializeELM();
    }

    /**
     * Initialize or reinitialize the underlying OnlineELM with `currentHiddenUnits`.
     * @param {number} [inputDim] - input width; defaults to the existing ELM's
     *   `inputDim`. When neither is available this does nothing.
     */
    _initializeELM(inputDim) {
        let dim = inputDim;
        if (dim === undefined && this.elm && typeof this.elm.inputDim === 'number') {
            dim = this.elm.inputDim;
        }
        if (dim === undefined) {
            return;
        }
        this.elm = new OnlineELM({
            inputDim: dim,
            outputDim: this.categories.length,
            hiddenUnits: this.currentHiddenUnits,
            activation: this.options.activation,
        });
    }

    /**
     * Train on a batch, record its error rate, then possibly resize the ELM.
     * @param {Array} X - input rows; the first row's length sets the input
     *   dimension when the ELM is created.
     * @param {Array} y - class indices, category names, or one-hot rows.
     * @throws {Error} when the ELM cannot be created (e.g. empty X on a new model).
     * @throws {RangeError} for a label that is neither a valid index nor a category.
     */
    fit(X, y) {
        const oneHotY = this._toOneHot(y);
        const needsInit = !this.elm || this.elm.inputDim === 0;
        if (needsInit && X.length > 0) {
            this._initializeELM(X[0].length);
        }
        if (!this.elm) {
            throw new Error('Failed to initialize ELM model');
        }
        this.elm.fit(X, oneHotY);
        this._remember(X, oneHotY);
        const error = this._evaluateError(X, oneHotY);
        this._recordError(error);
        this.trained = true;
        // May replace (and refit) this.elm; see _resize.
        this._adaptHiddenUnits(error);
    }

    /**
     * Incremental update with adaptive adjustment.
     *
     * The sample is scored BEFORE the model learns from it (test-then-train), so
     * the recorded error reflects real predictive performance on unseen data.
     * @param {Array} x - a single input row.
     * @param {number|string|Array} y - class index, category name, or one-hot row.
     * @throws {Error} if fit() has not been called successfully.
     * @throws {RangeError} for a label that is neither a valid index nor a category.
     */
    update(x, y) {
        if (!this.trained || !this.elm) {
            throw new Error('Model must be initially trained with fit() before incremental updates');
        }
        const oneHotY = Array.isArray(y) || ArrayBuffer.isView(y)
            ? y
            : this._labelToOneHot(y);
        const logits = this.elm.predictLogitsFromVector(x);
        if (logits) {
            const miss = this._argmax(Array.from(logits)) !== this._argmax(oneHotY) ? 1 : 0;
            this.recentOutcomes.push(miss);
            if (this.recentOutcomes.length > this.options.errorWindow) {
                this.recentOutcomes.shift();
            }
        }
        this.elm.update([x], [oneHotY]);
        this._remember([x], [oneHotY]);
        const recentError = this._evaluateRecentError();
        this._recordError(recentError);
        // May replace (and refit) this.elm; see _resize.
        this._adaptHiddenUnits(recentError);
    }

    /**
     * Predict the top-K categories with softmax probabilities.
     *
     * When `x` is a batch (an array of rows), the top-K entries for every row are
     * concatenated into one flat array; rows whose logits are falsy are skipped.
     * @param {Array} x - a single input row or an array of rows.
     * @param {number} [topK=3]
     * @returns {Array<{label: *, prob: number}>}
     * @throws {Error} if the model has not been trained.
     */
    predict(x, topK = 3) {
        if (!this.trained || !this.elm) {
            throw new Error('Model must be trained before prediction');
        }
        const rows = Array.isArray(x[0]) || ArrayBuffer.isView(x[0]) ? x : [x];
        const results = [];
        for (const row of rows) {
            const logits = this.elm.predictLogitsFromVector(row);
            if (!logits) {
                continue;
            }
            const ranked = this._softmax(Array.from(logits))
                .map((prob, index) => ({ label: this.categories[index], prob }));
            ranked.sort((a, b) => b.prob - a.prob);
            for (let i = 0; i < Math.min(topK, ranked.length); i++) {
                results.push({ label: ranked[i].label, prob: ranked[i].prob });
            }
        }
        return results;
    }

    /**
     * Grow or shrink the hidden layer based on recorded error rates.
     *
     * Only errors recorded since the last resize are considered, and at least 5
     * are required; otherwise stale pre-resize errors would trigger immediate
     * repeated resizing.
     * @param {number} currentError - latest error (kept for API compatibility;
     *   the decision uses the recorded history).
     */
    _adaptHiddenUnits(currentError) {
        const fresh = Math.min(this.errorsSinceAdapt, this.errorHistory.length);
        if (fresh < 5) {
            return; // Need some history since the last resize
        }
        const mean = (values) => values.reduce((a, b) => a + b, 0) / values.length;
        const avgError = mean(this.errorHistory.slice(-Math.min(10, fresh)));
        const recentError = mean(this.errorHistory.slice(-3));
        const {
            growthThreshold, shrinkThreshold, growthFactor, shrinkFactor,
            minHiddenUnits, maxHiddenUnits,
        } = this.options;
        // Grow if error is high
        if (recentError > growthThreshold && this.currentHiddenUnits < maxHiddenUnits) {
            const grown = Math.min(maxHiddenUnits, Math.floor(this.currentHiddenUnits * growthFactor));
            if (grown > this.currentHiddenUnits) {
                this._resize(grown);
                return;
            }
        }
        // Shrink if error is low and stable
        if (recentError < shrinkThreshold &&
            avgError < shrinkThreshold &&
            this.currentHiddenUnits > minHiddenUnits) {
            const shrunk = Math.max(minHiddenUnits, Math.floor(this.currentHiddenUnits * shrinkFactor));
            if (shrunk < this.currentHiddenUnits) {
                this._resize(shrunk);
            }
        }
    }

    /**
     * Rebuild the ELM with `newUnits` hidden units and refit it from the replay
     * buffer. A freshly constructed ELM has untrained weights, so without a replay
     * the model is marked untrained and fit() must be called again.
     * @param {number} newUnits
     */
    _resize(newUnits) {
        const inputDim = this.elm && typeof this.elm.inputDim === 'number'
            ? this.elm.inputDim
            : undefined;
        this.currentHiddenUnits = newUnits;
        this.errorsSinceAdapt = 0;
        // Outcomes of the old model say nothing about the new one.
        this.recentOutcomes = [];
        if (inputDim === undefined) {
            return;
        }
        this._initializeELM(inputDim);
        if (this.bufferX.length > 0) {
            this.elm.fit(this.bufferX, this.bufferY);
        } else {
            this.trained = false;
        }
    }

    /**
     * Append samples to the bounded replay buffer, dropping the oldest ones.
     * Rows are stored by reference (not copied).
     */
    _remember(X, Y) {
        const capacity = this.options.bufferSize;
        if (!(capacity > 0)) {
            return;
        }
        for (let i = Math.max(0, X.length - capacity); i < X.length; i++) {
            this.bufferX.push(X[i]);
            this.bufferY.push(Y[i]);
        }
        const overflow = this.bufferX.length - capacity;
        if (overflow > 0) {
            this.bufferX.splice(0, overflow);
            this.bufferY.splice(0, overflow);
        }
    }

    /** Append an error to the (100-entry) history and count it as fresh evidence. */
    _recordError(error) {
        this.errorHistory.push(error);
        if (this.errorHistory.length > 100) {
            this.errorHistory.shift();
        }
        this.errorsSinceAdapt++;
    }

    /**
     * Misclassification rate of the current ELM on up to 100 rows of (X, y).
     * Rows are taken at an even stride so class-sorted data is not under-sampled.
     * @returns {number} error rate; 1.0 when no row could be scored.
     */
    _evaluateError(X, y) {
        if (!this.elm) {
            return 1.0;
        }
        const sampleCount = Math.min(100, X.length);
        let misses = 0;
        let scored = 0;
        for (let k = 0; k < sampleCount; k++) {
            const i = Math.floor((k * X.length) / sampleCount);
            // Same scoring path as predict().
            const logits = this.elm.predictLogitsFromVector(X[i]);
            if (!logits) {
                continue;
            }
            if (this._argmax(Array.from(logits)) !== this._argmax(y[i])) {
                misses++;
            }
            scored++;
        }
        return scored > 0 ? misses / scored : 1.0;
    }

    /**
     * Online error estimate: mean miss rate over the last `errorWindow` scored
     * update() samples. Falls back to the last recorded error (or 0.5 when there
     * is none) if no sample has been scored yet.
     */
    _evaluateRecentError() {
        if (this.recentOutcomes.length > 0) {
            return this.recentOutcomes.reduce((a, b) => a + b, 0) / this.recentOutcomes.length;
        }
        if (this.errorHistory.length === 0) {
            return 0.5;
        }
        return this.errorHistory[this.errorHistory.length - 1];
    }

    /**
     * Convert a class index or category name to a one-hot row.
     * Numeric strings (e.g. '1') that are not category names are treated as indices.
     * @throws {RangeError} for labels that map to no category.
     */
    _labelToOneHot(label) {
        let idx = typeof label === 'number' ? label : this.categories.indexOf(label);
        if (idx === -1 && typeof label === 'string' && /^\d+$/.test(label)) {
            idx = Number(label);
        }
        if (!Number.isInteger(idx) || idx < 0 || idx >= this.categories.length) {
            throw new RangeError(`Label ${String(label)} is not a valid category index or name`);
        }
        const oneHot = new Array(this.categories.length).fill(0);
        oneHot[idx] = 1;
        return oneHot;
    }

    /** Return `y` unchanged if it already holds one-hot rows, else encode each label. */
    _toOneHot(y) {
        const first = y[0];
        if (Array.isArray(first) || ArrayBuffer.isView(first)) {
            return y;
        }
        return y.map((label) => this._labelToOneHot(label));
    }

    /** Numerically stable softmax (shifts by the max logit before exponentiating). */
    _softmax(logits) {
        const max = Math.max(...logits);
        const exp = logits.map(x => Math.exp(x - max));
        const sum = exp.reduce((a, b) => a + b, 0);
        return exp.map(x => x / sum);
    }

    /** Index of the largest entry (first on ties); missing/NaN entries count as 0. */
    _argmax(arr) {
        let maxIdx = 0;
        let maxVal = arr[0] || 0;
        for (let i = 1; i < arr.length; i++) {
            if ((arr[i] || 0) > maxVal) {
                maxVal = arr[i] || 0;
                maxIdx = i;
            }
        }
        return maxIdx;
    }

    /**
     * Get current number of hidden units
     */
    getHiddenUnits() {
        return this.currentHiddenUnits;
    }

    /**
     * Get a copy of the error history (at most 100 most recent entries)
     */
    getErrorHistory() {
        return [...this.errorHistory];
    }
}
//# sourceMappingURL=adaptive-online-elm.js.map